Example #1
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    The given ``indices`` of size ``(set_size, subset_size)`` specifies subsets of the ``target``
    that each of the set_size rows should select. The `target` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has size
    ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.view(-1))

    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
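A minimal usage sketch (assuming the function above is in scope; the tensor sizes are made-up for illustration):

import torch

# 2 sequences of 5 tokens each, embedding size 3.
target = torch.randn(2, 5, 3)
# 4 sets, each selecting 2 token positions (all indices < sequence_length).
indices = torch.LongTensor([[0, 1], [1, 2], [2, 3], [3, 4]])
selected = flattened_index_select(target, indices)
print(selected.shape)  # torch.Size([2, 4, 2, 3])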
Example #2
    def _loss_helper(self,  # pylint: disable=inconsistent-return-statements
                     direction: int,
                     direction_embeddings: torch.Tensor,
                     direction_targets: torch.Tensor,
                     token_embeddings: torch.Tensor) -> Tuple[int, int]:
        mask = direction_targets > 0
        # we need to subtract 1 to undo the padding id since the softmax
        # does not include a padding dimension

        # shape (batch_size * timesteps, )
        non_masked_targets = direction_targets.masked_select(mask) - 1

        # shape (batch_size * timesteps, embedding_dim)
        non_masked_embeddings = direction_embeddings.masked_select(
                mask.unsqueeze(-1)
        ).view(-1, self._forward_dim)
        # note: need to return average loss across forward and backward
        # directions, but total sum loss across all batches.
        # Assuming batches include full sentences, forward and backward
        # directions have the same number of samples, so sum up loss
        # here then divide by 2 just below
        if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
            return self._softmax_loss(non_masked_embeddings, non_masked_targets)
        else:
            # we also need the token embeddings corresponding to the
            # the targets
            raise NotImplementedError("This requires SampledSoftmaxLoss, which isn't implemented yet.")
            # pylint: disable=unreachable
            non_masked_token_embeddings = self._get_target_token_embeddings(token_embeddings, mask, direction)
            return self._softmax(non_masked_embeddings,
                                 non_masked_targets,
                                 non_masked_token_embeddings)
Example #3
 def split_heads(self, x: torch.Tensor, k: bool = False):
     new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
     x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
     if k:
         return x.permute(0, 2, 3, 1)
     else:
         return x.permute(0, 2, 1, 3)
Example #4
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None):  #pylint: disable=arguments-differ
        if mask is not None:
            tokens = tokens * mask.unsqueeze(-1).float()

        # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the `num_tokens`
        # dimension.
        summed = tokens.sum(1)

        if self._averaged:
            if mask is not None:
                lengths = get_lengths_from_binary_sequence_mask(mask)
                length_mask = (lengths > 0)

                # Set any length 0 to 1, to avoid dividing by zero.
                lengths = torch.max(lengths, Variable(lengths.data.new().resize_(1).fill_(1)))
            else:
                lengths = Variable(tokens.data.new().resize_(1).fill_(tokens.size(1)), requires_grad=False)
                length_mask = None

            summed = summed / lengths.unsqueeze(-1).float()

            if length_mask is not None:
                summed = summed * (length_mask > 0).float().unsqueeze(-1)

        return summed
Example #5
def PeepholeLSTMCell(input: torch.Tensor,
                     hidden: Tuple[torch.Tensor, torch.Tensor],
                     w_ih: torch.Tensor,
                     w_hh: torch.Tensor,
                     w_ip: torch.Tensor,
                     w_fp: torch.Tensor,
                     w_op: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    An LSTM cell with peephole connections without biases.

    Mostly ripped from the pytorch autograd lstm implementation.
    """
    hx, cx = hidden
    gates = F.linear(input, w_ih) + F.linear(hx, w_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    peep_i = w_ip.unsqueeze(0).expand_as(cx) * cx
    ingate = ingate + peep_i
    peep_f = w_fp.unsqueeze(0).expand_as(cx) * cx
    forgetgate = forgetgate + peep_f

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    peep_o = w_op.unsqueeze(0).expand_as(cy) * cy
    outgate = outgate + peep_o
    hy = outgate * F.tanh(cy)

    return hy, cy
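A hedged usage sketch showing the weight shapes this cell expects (the names and sizes are illustrative assumptions, not part of the snippet):

import torch
import torch.nn.functional as F

batch_size, input_size, hidden_size = 2, 4, 3
inp = torch.randn(batch_size, input_size)
hx = torch.zeros(batch_size, hidden_size)
cx = torch.zeros(batch_size, hidden_size)
# Stacked gate weights: 4 * hidden_size rows for the input, forget, cell and output gates.
w_ih = torch.randn(4 * hidden_size, input_size)
w_hh = torch.randn(4 * hidden_size, hidden_size)
# Peephole weights, one per cell unit.
w_ip = torch.randn(hidden_size)
w_fp = torch.randn(hidden_size)
w_op = torch.randn(hidden_size)
hy, cy = PeepholeLSTMCell(inp, (hx, cx), w_ih, w_hh, w_ip, w_fp, w_op)
print(hy.shape, cy.shape)  # torch.Size([2, 3]) torch.Size([2, 3])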
Example #6
def sample_regions(lb: Tensor, ub: Tensor, K: int, depth: int) -> Tuple[Tensor, Tensor]:
    """ Uniformly sample K sub-regions with fixed width boundaries for each sub-region.
    :param lb: Lower bounds, batched
    :param ub: Upper bounds, batched
    :param K: how many pieces to sample
    :param depth: bisecting original region width @depth times for sampling
    """
    assert valid_lb_ub(lb, ub)
    assert K >= 1 and depth >= 1

    repeat_dims = [1] * (len(lb.size()) - 1)
    base = lb.repeat(K, *repeat_dims)  # repeat K times in the batch, preserving the rest dimensions
    orig_width = ub - lb

    try:
        piece_width = orig_width / (2 ** depth)
        # print('Piece width:', piece_width)
        avail_width = orig_width - piece_width
    except RuntimeError as e:
        print('Numerical error at depth', depth)
        raise e

    piece_width = piece_width.repeat(K, *repeat_dims)
    avail_width = avail_width.repeat(K, *repeat_dims)

    coefs = torch.rand_like(base)
    lefts = base + coefs * avail_width
    rights = lefts + piece_width
    return lefts, rights
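A small usage sketch (it assumes valid_lb_ub, shown in Example #16 below, is also in scope):

import torch

lb = torch.zeros(3, 2)
ub = torch.ones(3, 2)
lefts, rights = sample_regions(lb, ub, K=4, depth=2)
print(lefts.shape, rights.shape)  # torch.Size([12, 2]) torch.Size([12, 2])
# Every sampled sub-region has width 1 / 2**depth = 0.25 in each dimension.
print(torch.allclose(rights - lefts, torch.full_like(lefts, 0.25)))  # True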
Example #7
 def forward(self,  # pylint: disable=arguments-differ
             matrix_1: torch.Tensor,
             matrix_2: torch.Tensor) -> torch.Tensor:
     combined_tensors = util.combine_tensors_and_multiply(self._combination,
                                                          [matrix_1.unsqueeze(2), matrix_2.unsqueeze(1)],
                                                          self._weight_vector)
     return self._activation(combined_tensors + self._bias)
Example #8
def neg_branin(X: Tensor) -> Tensor:
    r"""Negative Branin test function.

    Two-dimensional function (usually evaluated on `[-5, 10] x [0, 15]`):

        `B(x) = (x_2 - b x_1^2 + c x_1 - r)^2 + 10 (1 - t) cos(x_1) + 10`

    B has 3 minimizers for its global minimum at

        `z_1 = (-pi, 12.275), z_2 = (pi, 2.275), z_3 = (9.42478, 2.475)`

    with `B(z_i) = 0.397887`

    Args:
        X: A Tensor of size `2` or `k x 2` (`k` batch evaluations).

    Returns:
        `-B(X)`, the negative value of the standard Branin function.
    """
    batch = X.ndimension() > 1
    X = X if batch else X.unsqueeze(0)
    t1 = X[:, 1] - 5.1 / (4 * math.pi ** 2) * X[:, 0] ** 2 + 5 / math.pi * X[:, 0] - 6
    t2 = 10 * (1 - 1 / (8 * math.pi)) * torch.cos(X[:, 0])
    B = t1 ** 2 + t2 + 10
    result = -B
    return result if batch else result.squeeze(0)
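A quick sanity check against the documented minimizers (assuming the function above is in scope):

import math
import torch

# Single point: one of the three minimizers, so the result should be about -0.397887.
print(neg_branin(torch.tensor([math.pi, 2.275])))  # tensor(-0.3979)
# Batched evaluation of all three minimizers.
Z = torch.tensor([[-math.pi, 12.275], [math.pi, 2.275], [9.42478, 2.475]])
print(neg_branin(Z))  # tensor([-0.3979, -0.3979, -0.3979])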
Example #9
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, ...).
        gold_labels : ``torch.Tensor``, required.
            A tensor of the same shape as ``predictions``.
        mask: ``torch.Tensor``, optional (default = None).
            A tensor of the same shape as ``predictions``.
        """
        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)

        if mask is not None:
            # We can multiply by the mask up front, because we're just checking equality below, and
            # this way everything that's masked will be equal.
            predictions = predictions * mask
            gold_labels = gold_labels * mask

        batch_size = predictions.size(0)
        predictions = predictions.view(batch_size, -1)
        gold_labels = gold_labels.view(batch_size, -1)

        # The .prod() here is functioning as a logical and.
        correct = predictions.eq(gold_labels).prod(dim=1).float()
        count = torch.ones(gold_labels.size(0))
        self._correct_count += correct.sum()
        self._total_count += count.sum()
Example #10
def get_final_encoder_states(encoder_outputs: torch.Tensor,
                             mask: torch.Tensor,
                             bidirectional: bool = False) -> torch.Tensor:
    """
    Given the output from a ``Seq2SeqEncoder``, with shape ``(batch_size, sequence_length,
    encoding_dim)``, this method returns the final hidden state for each element of the batch,
    giving a tensor of shape ``(batch_size, encoding_dim)``.  This is not as simple as
    ``encoder_outputs[:, -1]``, because the sequences could have different lengths.  We use the
    mask (which has shape ``(batch_size, sequence_length)``) to find the final state for each batch
    instance.

    Additionally, if ``bidirectional`` is ``True``, we will split the final dimension of the
    ``encoder_outputs`` into two and assume that the first half is for the forward direction of the
    encoder and the second half is for the backward direction.  We will concatenate the last state
    for each encoder dimension, giving ``encoder_outputs[:, -1, :encoding_dim/2]`` concatenated with
    ``encoder_outputs[:, 0, encoding_dim/2:]``.
    """
    # These are the indices of the last words in the sequences (i.e. length sans padding - 1).  We
    # are assuming sequences are right padded.
    # Shape: (batch_size,)
    last_word_indices = mask.sum(1).long() - 1
    batch_size, _, encoder_output_dim = encoder_outputs.size()
    expanded_indices = last_word_indices.view(-1, 1, 1).expand(batch_size, 1, encoder_output_dim)
    # Shape: (batch_size, 1, encoder_output_dim)
    final_encoder_output = encoder_outputs.gather(1, expanded_indices)
    final_encoder_output = final_encoder_output.squeeze(1)  # (batch_size, encoder_output_dim)
    if bidirectional:
        final_forward_output = final_encoder_output[:, :(encoder_output_dim // 2)]
        final_backward_output = encoder_outputs[:, 0, (encoder_output_dim // 2):]
        final_encoder_output = torch.cat([final_forward_output, final_backward_output], dim=-1)
    return final_encoder_output
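A usage sketch illustrating that the mask, not position -1, decides which state is returned (shapes are illustrative):

import torch

encoder_outputs = torch.randn(2, 5, 6)  # (batch_size, sequence_length, encoding_dim)
# Right-padded mask: the first sequence has length 3, the second length 5.
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
final_states = get_final_encoder_states(encoder_outputs, mask, bidirectional=False)
print(final_states.shape)                                   # torch.Size([2, 6])
print(torch.equal(final_states[0], encoder_outputs[0, 2]))  # True: last unpadded position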
Example #11
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, ...).
        gold_labels : ``torch.Tensor``, required.
            A tensor of the same shape as ``predictions``.
        mask: ``torch.Tensor``, optional (default = None).
            A tensor of the same shape as ``predictions``.
        """
        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)
        # Flatten predictions, gold_labels, and mask. We calculate the covariance between
        # the vectors, since each element in the predictions and gold_labels tensor is assumed
        # to be a separate observation.
        predictions = predictions.view(-1)
        gold_labels = gold_labels.view(-1)

        if mask is not None:
            mask = mask.view(-1)
            predictions = predictions * mask
            gold_labels = gold_labels * mask
            num_batch_items = torch.sum(mask).item()
        else:
            num_batch_items = gold_labels.numel()

        # Note that self._total_count must be a float or int at all times
        # If it is a 1-dimension Tensor, the previous count will equal the updated_count.
        # The same applies for previous_total_prediction_mean and
        # previous_total_label_mean below -- we handle this in the code by
        # calling .item() judiciously.
        previous_count = self._total_count
        updated_count = self._total_count + num_batch_items

        batch_mean_prediction = torch.sum(predictions) / num_batch_items
        delta_mean_prediction = ((batch_mean_prediction - self._total_prediction_mean) *
                                 num_batch_items) / updated_count
        previous_total_prediction_mean = self._total_prediction_mean
        self._total_prediction_mean += delta_mean_prediction.item()

        batch_mean_label = torch.sum(gold_labels) / num_batch_items
        delta_mean_label = ((batch_mean_label - self._total_label_mean) * num_batch_items) / updated_count
        previous_total_label_mean = self._total_label_mean
        self._total_label_mean += delta_mean_label.item()

        batch_coresiduals = (predictions - batch_mean_prediction) * (gold_labels - batch_mean_label)
        if mask is not None:
            batch_co_moment = torch.sum(batch_coresiduals * mask)
        else:
            batch_co_moment = torch.sum(batch_coresiduals)
        delta_co_moment = (
                batch_co_moment + (previous_total_prediction_mean - batch_mean_prediction) *
                (previous_total_label_mean - batch_mean_label) *
                (previous_count * num_batch_items / updated_count))
        self._total_co_moment += delta_co_moment.item()
        self._total_count = updated_count
Example #12
def l2_distance(x: torch.Tensor, y: torch.Tensor) \
        -> torch.Tensor:
    """Compute the Gram matrix holding all ||.||_2 distances."""
    xTy = 2 * x.matmul(y.transpose(0, 1))
    x2 = torch.sum(x ** 2, dim=1)[:, None]
    y2 = torch.sum(y ** 2, dim=1)[None, :]
    K = x2 + y2 - xTy
    return K
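A short check of what the returned matrix contains (this sketch assumes the function above is in scope):

import torch

x = torch.randn(4, 8)
y = torch.randn(6, 8)
K = l2_distance(x, y)
print(K.shape)  # torch.Size([4, 6])
# Each entry is a squared Euclidean distance: K[i, j] == ||x[i] - y[j]||^2.
print(torch.allclose(K[0, 0], ((x[0] - y[0]) ** 2).sum(), atol=1e-5))  # True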
Example #13
 def forward(self, x: torch.Tensor) -> torch.Tensor:
     if self.rf == 1:
         size_out = x.size()[:-1] + (self.nf,)
         x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
         x = x.view(*size_out)
     else:
         raise NotImplementedError
     return x
Example #14
def add_sentence_boundary_token_ids(tensor: torch.Tensor,
                                    mask: torch.Tensor,
                                    sentence_begin_token: Any,
                                    sentence_end_token: Any) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Add begin/end of sentence tokens to the batch of sentences.
    Given a batch of sentences with size ``(batch_size, timesteps)`` or
    ``(batch_size, timesteps, dim)`` this returns a tensor of shape
    ``(batch_size, timesteps + 2)`` or ``(batch_size, timesteps + 2, dim)`` respectively.

    Returns both the new tensor and updated mask.

    Parameters
    ----------
    tensor : ``torch.Tensor``
        A tensor of shape ``(batch_size, timesteps)`` or ``(batch_size, timesteps, dim)``
    mask : ``torch.Tensor``
         A tensor of shape ``(batch_size, timesteps)``
    sentence_begin_token: Any (anything that can be broadcast in torch for assignment)
        For 2D input, a scalar with the <S> id. For 3D input, a tensor with length dim.
    sentence_end_token: Any (anything that can be broadcast in torch for assignment)
        For 2D input, a scalar with the </S> id. For 3D input, a tensor with length dim.

    Returns
    -------
    tensor_with_boundary_tokens : ``torch.Tensor``
        The tensor with the appended and prepended boundary tokens. If the input was 2D,
        it has shape (batch_size, timesteps + 2) and if the input was 3D, it has shape
        (batch_size, timesteps + 2, dim).
    new_mask : ``torch.Tensor``
        The new mask for the tensor, taking into account the appended tokens
        marking the beginning and end of the sentence.
    """
    # TODO: matthewp, profile this transfer
    sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
    tensor_shape = list(tensor.data.shape)
    new_shape = list(tensor_shape)
    new_shape[1] = tensor_shape[1] + 2
    tensor_with_boundary_tokens = tensor.new_zeros(*new_shape)
    if len(tensor_shape) == 2:
        tensor_with_boundary_tokens[:, 1:-1] = tensor
        tensor_with_boundary_tokens[:, 0] = sentence_begin_token
        for i, j in enumerate(sequence_lengths):
            tensor_with_boundary_tokens[i, j + 1] = sentence_end_token
        new_mask = (tensor_with_boundary_tokens != 0).long()
    elif len(tensor_shape) == 3:
        tensor_with_boundary_tokens[:, 1:-1, :] = tensor
        for i, j in enumerate(sequence_lengths):
            tensor_with_boundary_tokens[i, 0, :] = sentence_begin_token
            tensor_with_boundary_tokens[i, j + 1, :] = sentence_end_token
        new_mask = ((tensor_with_boundary_tokens > 0).long().sum(dim=-1) > 0).long()
    else:
        raise ValueError("add_sentence_boundary_token_ids only accepts 2D and 3D input")

    return tensor_with_boundary_tokens, new_mask
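A 2D usage sketch; the boundary ids 101 and 102 are hypothetical placeholders for the <S> and </S> ids:

import torch

token_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])  # 0 is padding
mask = (token_ids != 0).long()
new_tensor, new_mask = add_sentence_boundary_token_ids(token_ids, mask, 101, 102)
print(new_tensor)
# tensor([[101,   5,   6,   7, 102,   0],
#         [101,   8,   9, 102,   0,   0]])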
Example #15
def _safe_sparse_mask(tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    In PyTorch 1.0, Tensor._sparse_mask was changed to Tensor.sparse_mask.
    This wrapper allows AllenNLP to (temporarily) work with both 1.0 and 0.4.1.
    """
    # pylint: disable=protected-access
    try:
        return tensor.sparse_mask(mask)
    except AttributeError:
        # TODO(joelgrus): remove this and/or warn at some point
        return tensor._sparse_mask(mask)
Example #16
def valid_lb_ub(lb: Tensor, ub: Tensor) -> bool:
    """ To be valid:
        (1) Size ==
        (2) LB <= UB
    """
    if lb.size() != ub.size():
        return False

    # '<=' will return a uint8 tensor of 1 or 0 for each element, it should have all 1s.
    rel = lb <= ub
    return torch.equal(rel, torch.ones_like(rel))
Example #17
 def _get_target_token_embeddings(self,
                                  token_embeddings: torch.Tensor,
                                  mask: torch.Tensor,
                                  direction: int) -> torch.Tensor:
     # Need to shift the mask in the correct direction
     zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte()
     if direction == 0:
         # forward direction, get token to right
         shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
     else:
         shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
     return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(-1, self._forward_dim)
Example #18
    def _greedy_decode(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       attended_arcs: torch.Tensor,
                       mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.

        Returns
        -------
        heads : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = (1 - mask).byte().unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -numpy.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags
Example #19
 def forward(self,  # pylint: disable=arguments-differ
             vector: torch.Tensor,
             matrix: torch.Tensor,
             matrix_mask: torch.Tensor = None) -> torch.Tensor:
     tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
                                               matrix.size()[1],
                                               vector.size()[1])
     similarities = self._similarity_function(tiled_vector, matrix)
     if self._normalize:
         return masked_softmax(similarities, matrix_mask)
     else:
         return similarities
Example #20
    def forward(self,    # pylint: disable=arguments-differ
                inputs: torch.Tensor) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """
        Parameters
        ----------
        inputs : ``torch.autograd.Variable``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.
            We also accept tensors with additional optional dimensions:
            ``(batch_size, dim0, dim1, ..., dimn, timesteps, 50)``

        Returns
        -------
        Dict with keys:
        ``'elmo_representations'``: ``List[torch.autograd.Variable]``
            A ``num_output_representations`` list of ELMo representations for the input sequence.
            Each representation is shape ``(batch_size, timesteps, embedding_dim)``
        ``'mask'``:  ``torch.autograd.Variable``
            Shape ``(batch_size, timesteps)`` long tensor with sequence mask.
        """
        # reshape the input if needed
        original_shape = inputs.size()
        timesteps, num_characters = original_shape[-2:]
        if len(original_shape) > 3:
            reshaped_inputs = inputs.view(-1, timesteps, num_characters)
        else:
            reshaped_inputs = inputs

        # run the biLM
        bilm_output = self._elmo_lstm(reshaped_inputs)
        layer_activations = bilm_output['activations']
        mask_with_bos_eos = bilm_output['mask']

        # compute the elmo representations
        representations = []
        for i in range(len(self._scalar_mixes)):
            scalar_mix = getattr(self, 'scalar_mix_{}'.format(i))
            representation_with_bos_eos = scalar_mix(layer_activations, mask_with_bos_eos)
            representation_without_bos_eos, mask_without_bos_eos = remove_sentence_boundaries(
                    representation_with_bos_eos, mask_with_bos_eos
            )
            representations.append(self._dropout(representation_without_bos_eos))

        # reshape if necessary
        if len(original_shape) > 3:
            mask = mask_without_bos_eos.view(original_shape[:-1])
            elmo_representations = [representation.view(original_shape[:-1] + (-1, ))
                                    for representation in representations]
        else:
            mask = mask_without_bos_eos
            elmo_representations = representations

        return {'elmo_representations': elmo_representations, 'mask': mask}
Example #21
def batched_index_select(target: torch.Tensor,
                         indices: torch.LongTensor,
                         flattened_indices: Optional[torch.LongTensor] = None) -> torch.Tensor:
    """
    The given ``indices`` of size ``(batch_size, d_1, ..., d_n)`` indexes into the sequence
    dimension (dimension 2) of the target, which has size ``(batch_size, sequence_length,
    embedding_size)``.

    This function returns selected values in the target with respect to the provided indices, which
    have size ``(batch_size, d_1, ..., d_n, embedding_size)``. This can use the optionally
    precomputed :func:`~flattened_indices` with size ``(batch_size * d_1 * ... * d_n)`` if given.

    An example use case of this function is looking up the start and end indices of spans in a
    sequence tensor. This is used in the
    :class:`~allennlp.models.coreference_resolution.CoreferenceResolver` model to select
    contextual word representations corresponding to the start and end indices of mentions. The key
    reason this can't be done with basic torch functions is that we want to be able to use look-up
    tensors with an arbitrary number of dimensions (for example, in the coref model, we don't know
    a-priori how many spans we are looking up).

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed.
    indices : ``torch.LongTensor``
        A tensor of shape (batch_size, ...), where each element is an index into the
        ``sequence_length`` dimension of the ``target`` tensor.
    flattened_indices : Optional[torch.Tensor], optional (default = None)
        An optional tensor representing the result of calling :func:`~flatten_and_batch_shift_indices`
        on ``indices``. This is helpful in the case that the indices can be flattened once and
        cached for many batch lookups.

    Returns
    -------
    selected_targets : ``torch.Tensor``
        A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
        extracted from the batch flattened target tensor.
    """
    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))

    # Shape: (batch_size * sequence_length, embedding_size)
    flattened_target = target.view(-1, target.size(-1))

    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    flattened_selected = flattened_target.index_select(0, flattened_indices)
    selected_shape = list(indices.size()) + [target.size(-1)]
    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    selected_targets = flattened_selected.view(*selected_shape)
    return selected_targets
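A usage sketch for the span-endpoint lookup described in the docstring (it assumes the companion helper flatten_and_batch_shift_indices from the same module is also in scope):

import torch

target = torch.randn(2, 7, 4)  # (batch_size, sequence_length, embedding_size)
# Two spans per batch element, each given by (start, end) indices into the sequence.
span_endpoints = torch.LongTensor([[[0, 2], [3, 5]],
                                   [[1, 1], [4, 6]]])
selected = batched_index_select(target, span_endpoints)
print(selected.shape)  # torch.Size([2, 2, 2, 4])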
Example #22
def replace_masked_values(tensor: torch.Tensor, mask: torch.Tensor, replace_with: float) -> torch.Tensor:
    """
    Replaces all masked values in ``tensor`` with ``replace_with``.  ``mask`` must be broadcastable
    to the same shape as ``tensor``. We require that ``tensor.dim() == mask.dim()``, as otherwise we
    won't know which dimensions of the mask to unsqueeze.
    """
    # We'll build a tensor of the same shape as `tensor`, zero out masked values, then add back in
    # the `replace_with` value.
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    one_minus_mask = 1.0 - mask
    values_to_add = replace_with * one_minus_mask
    return tensor * mask + values_to_add
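A common use is pushing padded positions to a very negative value before a max or softmax; a minimal sketch, assuming the function above is in scope:

import torch

vector = torch.tensor([[0.2, 0.7, 0.1, 0.5]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
masked_vector = replace_masked_values(vector, mask, -1e7)
print(masked_vector)                     # tensor([[ 0.2,  0.7, -1e7, -1e7]]) (up to print formatting)
print(masked_vector.max(dim=-1).values)  # tensor([0.7000])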
Example #23
 def decorated(cls: Any, X: Tensor) -> Any:
     if X.dim() < 2:
         raise ValueError(
             f"{type(cls).__name__} requires X to have at least 2 dimensions,"
             f" but received X with only {X.dim()} dimensions."
         )
     elif expected_q is not None and X.shape[-2] != expected_q:
         raise AssertionError(
             f"Expected X to be `batch_shape x q={expected_q} x d`, but"
             f" got X with shape {X.shape}."
         )
     X = X if X.dim() > 2 else X.unsqueeze(0)
     return method(cls, X)
Example #24
 def forward(self, line: torch.Tensor) -> np.ndarray:
     """
     Performs a forward pass on a torch tensor of a line with shape (C, H, W)
     and returns a numpy array (W, C).
     """
     # make CHW -> 1CHW
     line = line.to(self.device)
     line = line.unsqueeze(0)
     o = self.nn.nn(line)
     if o.size(2) != 1:
         raise KrakenInputException('Expected dimension 3 to be 1, actual {}'.format(o.size()))
     self.outputs = o.detach().squeeze().cpu().numpy()
     return self.outputs
Example #25
    def _construct_loss(self,
                        arc_scores: torch.Tensor,
                        arc_tag_logits: torch.Tensor,
                        arc_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for an adjacency matrix.

        Parameters
        ----------
        arc_scores : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a
            binary classification decision for whether an edge is present between two words.
        arc_tag_logits : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used to generate
            a distribution over edge tags for a given edge.
        arc_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length).
            The labels for every arc.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        arc_indices = (arc_tags != -1).float()
        # Make the arc tags not have negative values anywhere
        # (by default, no edge is indicated with -1).
        arc_tags = arc_tags * arc_indices
        arc_nll = self._arc_loss(arc_scores, arc_indices) * float_mask.unsqueeze(1) * float_mask.unsqueeze(2)
        # We want the mask for the tags to only include the unmasked words
        # and we only care about the loss with respect to the gold arcs.
        tag_mask = float_mask.unsqueeze(1) * float_mask.unsqueeze(2) * arc_indices

        batch_size, sequence_length, _, num_tags = arc_tag_logits.size()
        original_shape = [batch_size, sequence_length, sequence_length]
        reshaped_logits = arc_tag_logits.view(-1, num_tags)
        reshaped_tags = arc_tags.view(-1)
        tag_nll = self._tag_loss(reshaped_logits, reshaped_tags.long()).view(original_shape) * tag_mask

        valid_positions = tag_mask.sum()

        arc_nll = arc_nll.sum() / valid_positions.float()
        tag_nll = tag_nll.sum() / valid_positions.float()
        return arc_nll, tag_nll
Example #26
def attention(query: torch.Tensor,
              key: torch.Tensor,
              value: torch.Tensor,
              mask: torch.Tensor = None,
              dropout: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute 'Scaled Dot Product Attention'"""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
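A usage sketch with a broadcastable padding mask (the shapes are illustrative assumptions; it also assumes the function above and its math / torch.nn.functional imports are in scope):

import torch

batch_size, num_heads, seq_len, d_k = 2, 4, 5, 8
query = torch.randn(batch_size, num_heads, seq_len, d_k)
key = torch.randn(batch_size, num_heads, seq_len, d_k)
value = torch.randn(batch_size, num_heads, seq_len, d_k)
mask = torch.ones(batch_size, 1, 1, seq_len)  # broadcasts over heads and query positions
output, p_attn = attention(query, key, value, mask=mask)
print(output.shape, p_attn.shape)  # torch.Size([2, 4, 5, 8]) torch.Size([2, 4, 5, 5])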
Example #27
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, ..., num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class label of shape (batch_size, ...). It must be the same
            shape as the ``predictions`` tensor without the ``num_classes`` dimension.
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        """
        # Get the data from the Variables.
        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)

        num_classes = predictions.size(-1)
        if (gold_labels >= num_classes).any():
            raise ConfigurationError("A gold label passed to F1Measure contains an id >= {}, "
                                     "the number of classes.".format(num_classes))
        if mask is None:
            mask = ones_like(gold_labels)
        mask = mask.float()
        gold_labels = gold_labels.float()
        positive_label_mask = gold_labels.eq(self._positive_label).float()
        negative_label_mask = 1.0 - positive_label_mask

        argmax_predictions = predictions.max(-1)[1].float().squeeze(-1)

        # True Negatives: correct non-positive predictions.
        correct_null_predictions = (argmax_predictions !=
                                    self._positive_label).float() * negative_label_mask
        self._true_negatives += (correct_null_predictions.float() * mask).sum()

        # True Positives: correct positively labeled predictions.
        correct_non_null_predictions = (argmax_predictions ==
                                        self._positive_label).float() * positive_label_mask
        self._true_positives += (correct_non_null_predictions * mask).sum()

        # False Negatives: incorrect negatively labeled predictions.
        incorrect_null_predictions = (argmax_predictions !=
                                      self._positive_label).float() * positive_label_mask
        self._false_negatives += (incorrect_null_predictions * mask).sum()

        # False Positives: incorrect positively labeled predictions
        incorrect_non_null_predictions = (argmax_predictions ==
                                          self._positive_label).float() * negative_label_mask
        self._false_positives += (incorrect_non_null_predictions * mask).sum()
Example #28
    def dot_prod_attention(self, h_t: torch.Tensor, src_encoding: torch.Tensor, src_encoding_att_linear: torch.Tensor,
                           mask: torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # (batch_size, src_sent_len)
        att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2)

        if mask is not None:
            att_weight.data.masked_fill_(mask.byte(), -float('inf'))

        softmaxed_att_weight = F.softmax(att_weight, dim=-1)

        att_view = (att_weight.size(0), 1, att_weight.size(1))
        # (batch_size, hidden_size)
        ctx_vec = torch.bmm(softmaxed_att_weight.view(*att_view), src_encoding).squeeze(1)

        return ctx_vec, softmaxed_att_weight
Example #29
def get_mask_from_sequence_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    """
    Given a variable of shape ``(batch_size,)`` that represents the sequence lengths of each batch
    element, this function returns a ``(batch_size, max_length)`` mask variable.  For example, if
    our input was ``[2, 2, 3]``, with a ``max_length`` of 4, we'd return
    ``[[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]``.

    We require ``max_length`` here instead of just computing it from the input ``sequence_lengths``
    because it lets us avoid finding the max, then copying that value from the GPU to the CPU so
    that we can use it to construct a new tensor.
    """
    # (batch_size, max_length)
    ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
    range_tensor = ones.cumsum(dim=1)
    return (sequence_lengths.unsqueeze(1) >= range_tensor).long()
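Reproducing the docstring's example (assuming the function above is in scope):

import torch

lengths = torch.tensor([2, 2, 3])
print(get_mask_from_sequence_lengths(lengths, max_length=4))
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0]])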
Example #30
    def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all word -> ROOT
            # edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = []
            for child, parent in enumerate(instance_heads):
                instance_head_tags.append(tag_ids[parent, child].item())
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))
Example #31
    def _compute_new_states(
            state: WikiTablesDecoderState,
            log_probs: torch.Tensor,
            hidden_state: torch.Tensor,
            memory_cell: torch.Tensor,
            action_embeddings: torch.Tensor,
            attended_question: torch.Tensor,
            attention_weights: torch.Tensor,
            considered_actions: List[List[int]],
            allowed_actions: List[Set[int]],
            max_actions: int = None) -> List[WikiTablesDecoderState]:
        # Each group index here might get accessed multiple times, and doing the slicing operation
        # each time is more expensive than doing it once upfront.  These three lines give about a
        # 10% speedup in training time.  I also tried this with sorted_log_probs and
        # action_embeddings, but those get accessed for _each action_, so doing the splits there
        # didn't help.
        hidden_state = [x.squeeze(0) for x in hidden_state.split(1, 0)]
        memory_cell = [x.squeeze(0) for x in memory_cell.split(1, 0)]
        attended_question = [
            x.squeeze(0) for x in attended_question.split(1, 0)
        ]

        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1,
                                                          descending=True)
        if max_actions is not None:
            # We might need a version of `sorted_log_probs` on the CPU later, but only if we need
            # to truncate the best states to `max_actions`.
            sorted_log_probs_cpu = sorted_log_probs.detach().cpu().numpy()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
        best_next_states: Dict[int, List[Tuple[int, int,
                                               int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(
                zip(state.batch_indices, sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if action == -1:
                    # This was padding.
                    continue
                if allowed_actions is not None and action not in allowed_actions[
                        group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append(
                    (group_index, action_index, action))
        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            if max_actions is not None:
                # We sorted previously by _group_index_, but we then combined by _batch_index_.  We
                # need to get the top next states for each _batch_ instance, so we sort all of the
                # instance's states again (across group index) by score.  We don't need to do this
                # if `max_actions` is None, because we'll be keeping all of the next states,
                # anyway.
                best_states.sort(key=lambda x: sorted_log_probs_cpu[x[:2]],
                                 reverse=True)
                best_states = best_states[:max_actions]
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                batch_index = state.batch_indices[group_index]
                new_action_history = state.action_history[group_index] + [
                    action
                ]
                new_score = sorted_log_probs[group_index, action_index]

                # `action_index` is the index in the _sorted_ tensors, but the action embedding
                # matrix is _not_ sorted, so we need to get back the original, non-sorted action
                # index before we get the action embedding.
                action_embedding_index = sorted_actions[group_index][
                    action_index]
                action_embedding = action_embeddings[group_index,
                                                     action_embedding_index, :]
                production_rule = state.possible_actions[batch_index][action][
                    0]
                new_grammar_state = state.grammar_state[
                    group_index].take_action(production_rule)
                if state.checklist_state[0] is not None:
                    new_checklist_state = [
                        state.checklist_state[group_index].update(action)
                    ]
                else:
                    new_checklist_state = None
                if state.debug_info is not None:
                    debug_info = {
                        'considered_actions': considered_actions[group_index],
                        'question_attention': attention_weights[group_index],
                        'probabilities': probs_cpu[group_index],
                    }
                    new_debug_info = [
                        state.debug_info[group_index] + [debug_info]
                    ]
                else:
                    new_debug_info = None

                new_rnn_state = RnnState(
                    hidden_state[group_index], memory_cell[group_index],
                    action_embedding, attended_question[group_index],
                    state.rnn_state[group_index].encoder_outputs,
                    state.rnn_state[group_index].encoder_output_mask)
                new_state = WikiTablesDecoderState(
                    batch_indices=[batch_index],
                    action_history=[new_action_history],
                    score=[new_score],
                    rnn_state=[new_rnn_state],
                    grammar_state=[new_grammar_state],
                    action_embeddings=state.action_embeddings,
                    output_action_embeddings=state.output_action_embeddings,
                    action_biases=state.action_biases,
                    action_indices=state.action_indices,
                    possible_actions=state.possible_actions,
                    flattened_linking_scores=state.flattened_linking_scores,
                    actions_to_entities=state.actions_to_entities,
                    entity_types=state.entity_types,
                    world=state.world,
                    example_lisp_string=state.example_lisp_string,
                    checklist_state=new_checklist_state,
                    debug_info=new_debug_info)
                new_states.append(new_state)
        return new_states
Example #32
    def forward(self, emb_inputs: torch.Tensor) -> torch.Tensor:
        """
        Forward calculation of DynamicRoutingLayer
        
        Args:
            emb_inputs (T), shape = (B, N, E), data_type = torch.float: embedded features tensors
        
        Returns:
            T, shape = (B, N = N_cap, O = ECap), data_type = torch.float: output of DynamicRoutingLayer
        """
        # Name the inputs' tensor for alignment
        emb_inputs.names = (
            'B',
            'N',
            'E',
        )

        # calculate number of interest capsules K
        # inputs: emb_inputs, shape = (B, N, E)
        # output: max_num_caps, int
        self.num_caps = self._dynamic_interest_number(emb_inputs.size('N'))

        # calculate priors = \hat{e}_{j|i} for each capsule
        # inputs: emb_inputs, shape = (B, N, E_i)
        # inputs: max_num_caps, int
        # output: priors, shape = (B, K, N, ECap)
        batch_size = emb_inputs.size('B')
        priors = torch.matmul(emb_inputs, self.S)
        priors = priors.unflatten('B', (
            (
                'B',
                batch_size,
            ),
            (
                'C',
                1,
            ),
        ))
        priors = priors.rename(None).repeat(1, self.num_caps, 1, 1)
        priors.names = ('B', 'K', 'N', 'ECap')

        # detach priors as priors_temp to prevent gradients from flowing
        # inputs: priors, shape = (B, K, N, ECap)
        # output: priors_temp, shape = (B, K, N, ECap)
        priors_temp = priors.detach()

        # initialize coupling coefficient by bij ∼ N(0, σ^2).
        # inputs: priors_temp, shape = (B, K, N, ECap)
        # output: coup_coefficient, shape = (B, K, N, ECap)
        coup_coefficient = torch.randn_like(priors_temp.rename(None),
                                            device=priors.device)
        coup_coefficient.names = priors_temp.names

        # update coupling coefficient by iterative dynamic routing process
        for _ in range(self.num_iter - 1):
            # take softmax along max_num_caps to calculate weights for behaviour capsule.
            # inputs: coup_coefficient, shape = (B, K, N, ECap)
            # output: weights, shape = (B, K, N, ECap)
            weights = torch.softmax(coup_coefficient, dim='K')

            # calculate z
            # inputs: weights, shape = (B, K, N, ECap)
            # inputs: u_hat, shape = (B, K, N, ECap)
            # output: z, shape = (B, K, ECap)
            z = (weights * priors_temp).sum(dim='N')

            # apply squashing non-linearity to z
            # inputs: z, shape = (B, K, ECap)
            # output: v, shape = (B, K, ECap)
            v = squash(z)

            # calculate dot product between v and \hat{u}_{j|i}
            # inputs: priors_temp, shape = (B, K, N, ECap)
            # inputs: v, shape = (B, K, ECap)
            # output: uv, shape = (B, K, N, ECap = 1)
            v_temp = v.unflatten('ECap', (
                (
                    'ECap',
                    v.size('ECap'),
                ),
                (
                    'N',
                    1,
                ),
            ))
            similarity = torch.matmul(priors_temp.rename(None),
                                      v_temp.rename(None))
            similarity.names = coup_coefficient.names

            # update bij for all behavior capsule i and interest capsule j
            # inputs: coup_coefficient, shape = (B, K, N, ECap)
            # inputs: similarity, shape = (B, K, N, ECap)
            # output: coup_coefficient, shape = (B, K, N, ECap)
            coup_coefficient = coup_coefficient + similarity

        # calculate output with the original u_hat without routing updates
        # inputs: priors, shape = (B, K, N, ECap)
        # inputs: coup_coefficient, shape = (B, K, N, ECap)
        # output: output, shape = (B, K', E)
        weights = torch.softmax(coup_coefficient, dim='K')
        z = (weights * priors).sum(dim='N')

        # apply squashing non-linearity to z
        # inputs: z, shape = (B, K, ECap)
        # output: output, shape = (B, K, ECap)
        output = squash(z)

        # rename output names to (B, N, O)
        output.names = (
            'B',
            'N',
            'O',
        )

        return output
Example #33
def uniq(a: Tensor) -> Set:
    return set(torch.unique(a.cpu()).numpy())
Example #34
def test_torch_Tensor(compress):
    t = Tensor(numpy.random.random((100, 100)))
    t_serialized = serialize(t, compress=compress)
    t_serialized_deserialized = deserialize(t_serialized, compressed=compress)
    assert (t == t_serialized_deserialized).all()
Example #35
 def forward(ctx, x: Tensor, inplace: bool = False):
     ctx.save_for_backward(x)
     x_ts = torch.tanh_(F.softplus(x))
     return x.mul_(x_ts) if inplace else x.mul(x_ts)
Example #36
 def sample_mask(self):
     keep = 1.0 - self.dropout
     self.mask = Variable(torch.bernoulli(Tensor(1, self.hidden_size).fill_(keep)))
Example #37
    def step(
            self, Ybar_t: torch.Tensor, dec_state: Tuple[torch.Tensor,
                                                         torch.Tensor],
            enc_hiddens: torch.Tensor, enc_hiddens_proj: torch.Tensor,
            enc_masks: torch.Tensor
    ) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
        """ Compute one forward step of the LSTM decoder, including the attention computation.

        @param Ybar_t (Tensor): Concatenated Tensor of [Y_t o_prev], with shape (b, e + h). The input for the decoder,
                                where b = batch size, e = embedding size, h = hidden size.
        @param dec_state (tuple(Tensor, Tensor)): Tuple of tensors both with shape (b, h), where b = batch size, h = hidden size.
                First tensor is decoder's prev hidden state, second tensor is decoder's prev cell.
        @param enc_hiddens (Tensor): Encoder hidden states Tensor, with shape (b, src_len, h * 2), where b = batch size,
                                    src_len = maximum source length, h = hidden size.
        @param enc_hiddens_proj (Tensor): Encoder hidden states Tensor, projected from (h * 2) to h. Tensor is with shape (b, src_len, h),
                                    where b = batch size, src_len = maximum source length, h = hidden size.
        @param enc_masks (Tensor): Tensor of sentence masks shape (b, src_len),
                                    where b = batch size, src_len is maximum source length. 

        @returns dec_state (tuple (Tensor, Tensor)): Tuple of tensors both shape (b, h), where b = batch size, h = hidden size.
                First tensor is decoder's new hidden state, second tensor is decoder's new cell.
        @returns combined_output (Tensor): Combined output Tensor at timestep t, shape (b, h), where b = batch size, h = hidden size.
        @returns e_t (Tensor): Tensor of shape (b, src_len). It is attention scores distribution.
                                Note: You will not use this outside of this function.
                                      We are simply returning this value so that we can sanity check
                                      your implementation.
        """

        combined_output = None

        ### YOUR CODE HERE (~3 Lines)
        ### TODO:
        ###     1. Apply the decoder to `Ybar_t` and `dec_state`to obtain the new dec_state.
        ###     2. Split dec_state into its two parts (dec_hidden, dec_cell)
        ###     3. Compute the attention scores e_t, a Tensor shape (b, src_len).
        ###        Note: b = batch_size, src_len = maximum source length, h = hidden size.
        ###
        ###       Hints:
        ###         - dec_hidden is shape (b, h) and corresponds to h^dec_t in the PDF (batched)
        ###         - enc_hiddens_proj is shape (b, src_len, h) and corresponds to W_{attProj} h^enc (batched).
        ###         - Use batched matrix multiplication (torch.bmm) to compute e_t.
        ###         - To get the tensors into the right shapes for bmm, you will need to do some squeezing and unsqueezing.
        ###         - When using the squeeze() function make sure to specify the dimension you want to squeeze
        ###             over. Otherwise, you will remove the batch dimension accidentally, if batch_size = 1.
        ###
        ### Use the following docs to implement this functionality:
        ###     Batch Multiplication:
        ###        https://pytorch.org/docs/stable/torch.html#torch.bmm
        ###     Tensor Unsqueeze:
        ###         https://pytorch.org/docs/stable/torch.html#torch.unsqueeze
        ###     Tensor Squeeze:
        ###         https://pytorch.org/docs/stable/torch.html#torch.squeeze

        ### END YOUR CODE

        # Set e_t to -inf where enc_masks has 1
        if enc_masks is not None:
            e_t.data.masked_fill_(enc_masks.byte(), -float('inf'))

        ### YOUR CODE HERE (~6 Lines)
        ### TODO:
        ###     1. Apply softmax to e_t to yield alpha_t
        ###     2. Use batched matrix multiplication between alpha_t and enc_hiddens to obtain the
        ###         attention output vector, a_t.
        ###     Hints:
        ###           - alpha_t is shape (b, src_len)
        ###           - enc_hiddens is shape (b, src_len, 2h)
        ###           - a_t should be shape (b, 2h)
        ###           - You will need to do some squeezing and unsqueezing.
        ###     Note: b = batch size, src_len = maximum source length, h = hidden size.
        ###
        ###     3. Concatenate dec_hidden with a_t to compute tensor U_t
        ###     4. Apply the combined output projection layer to U_t to compute tensor V_t
        ###     5. Compute tensor O_t by first applying the Tanh function and then the dropout layer.
        ###
        ### Use the following docs to implement this functionality:
        ###     Softmax:
        ###         https://pytorch.org/docs/stable/nn.html#torch.nn.functional.softmax
        ###     Batch Multiplication:
        ###        https://pytorch.org/docs/stable/torch.html#torch.bmm
        ###     Tensor View:
        ###         https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        ###     Tensor Concatenation:
        ###         https://pytorch.org/docs/stable/torch.html#torch.cat
        ###     Tanh:
        ###         https://pytorch.org/docs/stable/torch.html#torch.tanh

        ### END YOUR CODE

        combined_output = O_t
        return dec_state, combined_output, e_t
Beispiel #38
0
def gem(x: torch.Tensor, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p),
                        (x.size(-2), x.size(-1))).pow(1. / p)
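# Illustrative usage (not part of the original snippet): GeM (generalized-mean) pooling
# collapses a convolutional feature map of shape (N, C, H, W) into an (N, C, 1, 1)
# descriptor; p=1 recovers average pooling, while large p approaches max pooling.
example_feats = torch.randn(2, 512, 7, 7)
example_descriptor = gem(example_feats, p=3).flatten(1)  # shape (2, 512)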
Beispiel #39
0
 def tensor2text(t: torch.Tensor):
     return " ".join([str(w) for w in t.detach().numpy().ravel()]) + "\n"
Beispiel #40
0
    def forward(
        self,  # pylint: disable=arguments-differ
        inputs: torch.Tensor,
        mask: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, hidden_dim = inputs.size()
        if mask is None:
            mask = Variable(inputs.data.new(batch_size, timesteps).fill_(1.0))

        # Treat the queries, keys and values each as a ``num_heads`` size batch.
        # shape (num_heads, batch_size * timesteps, hidden_dim)
        inputs_per_head = inputs.repeat(num_heads, 1,
                                        1).view(num_heads,
                                                batch_size * timesteps,
                                                hidden_dim)
        # Do the projections for all the heads at once.
        # Then reshape the result as though it had a
        # (num_heads * batch_size) sized batch.
        queries_per_head = torch.bmm(inputs_per_head, self._query_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        queries_per_head = queries_per_head.view(num_heads * batch_size,
                                                 timesteps,
                                                 self._attention_dim)

        keys_per_head = torch.bmm(inputs_per_head, self._key_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        keys_per_head = keys_per_head.view(num_heads * batch_size, timesteps,
                                           self._attention_dim)

        values_per_head = torch.bmm(inputs_per_head, self._value_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        values_per_head = values_per_head.view(num_heads * batch_size,
                                               timesteps, self._values_dim)

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = (
            torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) /
            self._scale)

        # Masking should go here
        causality_mask = subsequent_mask(timesteps).cuda()
        masked_scaled_similarities = scaled_similarities.masked_fill(
            causality_mask == 0, -1e9)

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = masked_softmax(masked_scaled_similarities,
                                   mask.repeat(num_heads, 1))
        attention = self._attention_dropout(attention)
        # This is doing the following batch-wise matrix multiplication:
        # (num_heads * batch_size, timesteps, timesteps) *
        # (num_heads * batch_size, timesteps, values_dim)
        # which is equivalent to a weighted sum of the values with respect to
        # the attention distributions for each element in the num_heads * batch_size
        # dimension.
        # shape (num_heads * batch_size, timesteps, values_dim)
        outputs = torch.bmm(attention, values_per_head)

        # Reshape back to original shape (batch_size, timesteps, num_heads * values_dim)
        # Note that we _cannot_ use a reshape here, because this tensor was created
        # with num_heads being the first dimension, so reshaping naively would not
        # throw an error, but give an incorrect result.
        outputs = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
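# Hypothetical usage sketch (the enclosing class is not shown here; the constructor
# arguments below mirror AllenNLP's MultiHeadSelfAttention and are assumptions):
#   attention = MultiHeadSelfAttention(num_heads=8, input_dim=512,
#                                      attention_dim=512, values_dim=512)
#   outputs = attention(torch.randn(4, 20, 512))   # -> shape (4, 20, 512)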
Beispiel #41
0
 def __loss_lte_zero(parameter: torch.Tensor):
     return torch.where(parameter.le(0.), torch.ones(1),
                        torch.zeros(1)) * parameter.abs()
Beispiel #42
0
def get_image(distances: torch.Tensor, width: int, height: int) -> np.ndarray:
    distances = (
        distances.clamp_min(0).reshape((height, width)).detach().cpu().numpy()
    )
    distances = distances.astype(np.float32)
    return distances
Beispiel #43
0
 def aggregate_scores(self, batch_scores: Tensor):
     batch_scores = batch_scores.view(-1, batch_scores.size(-1))
     self.all_scores.append(batch_scores.tolist())
Beispiel #44
0
 def aggregate_targets(self, batch_targets: Tensor, batch_context: Dict[str, Any] = None):
     self.all_targets.append(batch_targets.flatten().tolist())
Beispiel #45
0
 def forward(self, x: torch.Tensor):
     x = x.permute((0, 3, 1, 2)).true_divide_(255)
     img_list = [x[i, :, :, :] for i in range(x.size()[0])]
     _, result = self.base_model(img_list)
     return result
Beispiel #46
0
def accuracy(y_pred: Tensor, y_true: Tensor):

    outputs = torch.argmax(y_pred, dim=1)
    return np.mean(outputs.cpu().numpy() == y_true.detach().cpu().numpy())
Beispiel #47
0
def average_precision(
        outputs: torch.Tensor,
        targets: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
        topk: Sequence[int] = (1, ),
) -> Sequence[torch.Tensor]:
    """Computes the average precision at `topk`.

    Args:
        outputs (torch.Tensor): NxK tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets (torch.Tensor): binary NxK tensor that encodes which of the K
            classes are associated with the N-th input
            (eg: a row [0, 1, 0, 1] indicates that the example is
            associated with classes 2 and 4)
        weights (torch.Tensor): importance for each sample
        topk (Sequence[int], optional): The maximum number of predicted elements

    Returns:
        Sequence[torch.Tensor]: list of 1xK tensor,
        with average precision@topk_i for K classes
    """
    assert len(topk) == 1 and topk[0] == 1, "@K logic is not implemented yet"
    # outputs - [bs; num_classes] with scores
    # targets - [bs; num_classes] with binary labels
    outputs, targets, weights = preprocess_multi_label_metrics(
        outputs=outputs,
        targets=targets,
        weights=weights,
    )
    if outputs.numel() == 0:
        return torch.zeros(1)

    ap = torch.zeros(targets.size(1))

    # compute average precision for each class
    for class_i in range(targets.size(1)):
        # sort scores
        class_scores = outputs[:, class_i]
        class_targets = targets[:, class_i]
        _, sortind = torch.sort(class_scores, dim=0, descending=True)
        correct = class_targets[sortind]

        # compute true positive sums
        if weights is not None:
            class_weight = weights[sortind]
            weighted_correct = correct.float() * class_weight

            tp = weighted_correct.cumsum(0)
            rg = class_weight.cumsum(0)
        else:
            tp = correct.float().cumsum(0)
            rg = torch.arange(1, targets.size(0) + 1).float()

        # compute precision curve
        precision = tp.div(rg)

        # compute average precision
        ap[class_i] = precision[correct.bool()].sum() / max(
            float(correct.sum()), 1)

    return [ap]
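# Illustrative usage (not part of the original snippet; it assumes the module's
# `preprocess_multi_label_metrics` helper is importable, as in the source above):
#   scores = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])   # N=3 samples, K=2 classes
#   labels = torch.tensor([[1., 0.], [0., 1.], [0., 1.]])
#   ap_per_class = average_precision(scores, labels)[0]            # tensor with K entries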
Beispiel #48
0
def dropout_mask(x: torch.Tensor, sz: Collection[int], p: float):
    "Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element."

    return x.new(*sz).bernoulli_(1 - p).div_(1 - p)
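# Illustrative usage (not part of the original snippet): the mask is scaled by 1/(1-p),
# so multiplying an input by it keeps the expected activation unchanged (inverted dropout).
example_x = torch.ones(3, 5)
example_mask = dropout_mask(example_x, (3, 5), p=0.5)  # entries are 0.0 or 2.0
example_dropped = example_x * example_mask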
Beispiel #49
0
        if do_dropout and self.dropout_method == 'moon':
            c_t.data.set_(torch.mul(c_t, self.mask).data)
            c_t.data *= 1.0 / (1.0 - self.dropout)

        if self.cln:
            c_t = self.ln_cell(c_t, image_emb)
        else:
            c_t = self.ln_cell(c_t)
        h_t = torch.mul(o_t, c_t.tanh())

        # Reshape for compatibility
        if do_dropout:
            if self.dropout_method == 'pytorch':
                F.dropout(h_t, p=self.dropout, training=self.training, inplace=True)
            if self.dropout_method == 'gal':
                h_t.data.set_(torch.mul(h_t, self.mask).data)
                h_t.data *= 1.0 / (1.0 - self.dropout)

        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)

if __name__ == '__main__':
    model = LSTM(50, 100, 2)
    x = Variable(Tensor(50, 32, 50))
    #h = model.init_hidden(32)
    h = (Variable(Tensor(2*2, 32, 100)),
         Variable(Tensor(2*2, 32, 100)))
    print(model(x, h))
Beispiel #50
0
def multi_head_attention_forward(query: Tensor,
                                 key: Tensor,
                                 value: Tensor,
                                 embed_dim_to_check: int,
                                 num_heads: int,
                                 in_proj_weight: Tensor,
                                 in_proj_bias: Tensor,
                                 bias_k: Optional[Tensor],
                                 bias_v: Optional[Tensor],
                                 add_zero_attn: bool,
                                 dropout_p: float,
                                 out_proj_weight: Tensor,
                                 out_proj_bias: Tensor,
                                 training: bool = True,
                                 key_padding_mask: Optional[Tensor] = None,
                                 need_weights: bool = True,
                                 attn_mask: Optional[Tensor] = None,
                                 use_separate_proj_weight: bool = False,
                                 q_proj_weight: Optional[Tensor] = None,
                                 k_proj_weight: Optional[Tensor] = None,
                                 v_proj_weight: Optional[Tensor] = None,
                                 static_k: Optional[Tensor] = None,
                                 static_v: Optional[Tensor] = None
                                 ) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    if not torch.jit.is_scripting():
        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
                    out_proj_weight, out_proj_bias)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                multi_head_attention_forward, tens_ops, query, key, value,
                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
                bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
                need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask


    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias) ## may be eliminated, also biases in other places

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
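# Hypothetical usage sketch (values are illustrative; shapes follow the docstring above):
# self-attention over a length-10 sequence, batch size 2, embed_dim 16 split across 4 heads.
#   q = k = v = torch.randn(10, 2, 16)
#   attn_out, attn_weights = multi_head_attention_forward(
#       q, k, v, embed_dim_to_check=16, num_heads=4,
#       in_proj_weight=torch.randn(48, 16), in_proj_bias=torch.zeros(48),
#       bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0.0,
#       out_proj_weight=torch.randn(16, 16), out_proj_bias=torch.zeros(16))
#   attn_out.shape == (10, 2, 16); attn_weights.shape == (2, 10, 10)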
Beispiel #51
0
    def decode(self, enc_hiddens: torch.Tensor, enc_masks: torch.Tensor,
               dec_init_state: Tuple[torch.Tensor, torch.Tensor],
               target_padded: torch.Tensor) -> torch.Tensor:
        """Compute combined output vectors for a batch.

        @param enc_hiddens (Tensor): Hidden states (b, src_len, h*2), where
                                     b = batch size, src_len = maximum source sentence length, h = hidden size.
        @param enc_masks (Tensor): Tensor of sentence masks (b, src_len), where
                                     b = batch size, src_len = maximum source sentence length.
        @param dec_init_state (tuple(Tensor, Tensor)): Initial state and cell for decoder
        @param target_padded (Tensor): Gold-standard padded target sentences (tgt_len, b), where
                                       tgt_len = maximum target sentence length, b = batch size. 

        @returns combined_outputs (Tensor): combined output tensor  (tgt_len, b,  h), where
                                        tgt_len = maximum target sentence length, b = batch_size,  h = hidden size
        """
        # Chop off the <END> token for max length sentences.
        target_padded = target_padded[:-1]

        # Initialize the decoder state (hidden and cell)
        dec_state = dec_init_state

        # Initialize previous combined output vector o_{t-1} as zero
        batch_size = enc_hiddens.size(0)
        o_prev = torch.zeros(batch_size, self.hidden_size, device=self.device)

        # Initialize a list we will use to collect the combined output o_t on each step
        combined_outputs = []

        ### YOUR CODE HERE (~9 Lines)
        ### TODO:
        ###     1. Apply the attention projection layer to `enc_hiddens` to obtain `enc_hiddens_proj`,
        ###         which should be shape (b, src_len, h),
        ###         where b = batch size, src_len = maximum source length, h = hidden size.
        ###         This is applying W_{attProj} to h^enc, as described in the PDF.
        ###     2. Construct tensor `Y` of target sentences with shape (tgt_len, b, e) using the target model embeddings.
        ###         where tgt_len = maximum target sentence length, b = batch size, e = embedding size.
        ###     3. Use the torch.split function to iterate over the time dimension of Y.
        ###         Within the loop, this will give you Y_t of shape (1, b, e) where b = batch size, e = embedding size.
        ###             - Squeeze Y_t into a tensor of dimension (b, e).
        ###             - Construct Ybar_t by concatenating Y_t with o_prev.
        ###             - Use the step function to compute the Decoder's next (cell, state) values
        ###               as well as the new combined output o_t.
        ###             - Append o_t to combined_outputs
        ###             - Update o_prev to the new o_t.
        ###     4. Use torch.stack to convert combined_outputs from a list of length tgt_len of
        ###         tensors of shape (b, h) to a single tensor of shape (tgt_len, b, h),
        ###         where tgt_len = maximum target sentence length, b = batch size, h = hidden size.
        ###
        ### Note:
        ###    - When using the squeeze() function make sure to specify the dimension you want to squeeze
        ###      over. Otherwise, you will remove the batch dimension accidentally, if batch_size = 1.
        ###
        ### Use the following docs to implement this functionality:
        ###     Zeros Tensor:
        ###         https://pytorch.org/docs/stable/torch.html#torch.zeros
        ###     Tensor Splitting (iteration):
        ###         https://pytorch.org/docs/stable/torch.html#torch.split
        ###     Tensor Dimension Squeezing:
        ###         https://pytorch.org/docs/stable/torch.html#torch.squeeze
        ###     Tensor Concatenation:
        ###         https://pytorch.org/docs/stable/torch.html#torch.cat
        ###     Tensor Stacking:
        ###         https://pytorch.org/docs/stable/torch.html#torch.stack
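        ### A minimal sketch of one possible solution (not the official one). It assumes the
        ### attention projection, target embeddings, and step function are stored as
        ### `self.att_projection`, `self.model_embeddings.target`, and `self.step`, as in
        ### the usual assignment skeleton:
        enc_hiddens_proj = self.att_projection(enc_hiddens)      # (b, src_len, h)
        Y = self.model_embeddings.target(target_padded)          # (tgt_len, b, e)
        for Y_t in torch.split(Y, 1, dim=0):
            Y_t = Y_t.squeeze(dim=0)                             # (b, e)
            Ybar_t = torch.cat((Y_t, o_prev), dim=1)             # (b, e + h)
            dec_state, o_t, _ = self.step(Ybar_t, dec_state, enc_hiddens,
                                          enc_hiddens_proj, enc_masks)
            combined_outputs.append(o_t)
            o_prev = o_t
        combined_outputs = torch.stack(combined_outputs, dim=0)  # (tgt_len, b, h)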

        ### END YOUR CODE

        return combined_outputs
Beispiel #52
0
def rotation_matrix_to_quaternion(
        rotation_matrix: torch.Tensor,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Convert 3x3 rotation matrix to 4d quaternion vector.

    Args:
        rotation_matrix (torch.Tensor): the rotation matrix to convert.
        eps (float): small value to avoid zero division. Default: 1e-8.

    Return:
        torch.Tensor: the rotation in quaternion.

    Shape:
        - Input: :math:`(*, 3, 3)`
        - Output: :math:`(*, 4)`

    Example:
        >>> input = torch.rand(4, 3, 3)  # Nx3x3
        >>> output = kornia.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not isinstance(rotation_matrix, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    if not rotation_matrix.shape[-2:] == (3, 3):
        raise ValueError(
            "Input size must be a (*, 3, 3) tensor. Got {}".format(
                rotation_matrix.shape))

    def safe_zero_division(numerator: torch.Tensor,
                           denominator: torch.Tensor) -> torch.Tensor:
        eps: float = torch.finfo(numerator.dtype).tiny  # type: ignore
        return numerator / torch.clamp(denominator, min=eps)

    rotation_matrix_vec: torch.Tensor = rotation_matrix.view(
        *rotation_matrix.shape[:-2], 9)

    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.chunk(
        rotation_matrix_vec, chunks=9, dim=-1)

    trace: torch.Tensor = m00 + m11 + m22

    def trace_positive_cond():
        sq = torch.sqrt(trace + 1.0) * 2.  # sq = 4 * qw.
        qw = 0.25 * sq
        qx = safe_zero_division(m21 - m12, sq)
        qy = safe_zero_division(m02 - m20, sq)
        qz = safe_zero_division(m10 - m01, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def cond_1():
        sq = torch.sqrt(1.0 + m00 - m11 - m22 + eps) * 2.  # sq = 4 * qx.
        qw = safe_zero_division(m21 - m12, sq)
        qx = 0.25 * sq
        qy = safe_zero_division(m01 + m10, sq)
        qz = safe_zero_division(m02 + m20, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def cond_2():
        sq = torch.sqrt(1.0 + m11 - m00 - m22 + eps) * 2.  # sq = 4 * qy.
        qw = safe_zero_division(m02 - m20, sq)
        qx = safe_zero_division(m01 + m10, sq)
        qy = 0.25 * sq
        qz = safe_zero_division(m12 + m21, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def cond_3():
        sq = torch.sqrt(1.0 + m22 - m00 - m11 + eps) * 2.  # sq = 4 * qz.
        qw = safe_zero_division(m10 - m01, sq)
        qx = safe_zero_division(m02 + m20, sq)
        qy = safe_zero_division(m12 + m21, sq)
        qz = 0.25 * sq
        return torch.cat([qx, qy, qz, qw], dim=-1)

    where_2 = torch.where(m11 > m22, cond_2(), cond_3())
    where_1 = torch.where(
        (m00 > m11) & (m00 > m22), cond_1(), where_2)

    quaternion: torch.Tensor = torch.where(
        trace > 0., trace_positive_cond(), where_1)
    return quaternion
Beispiel #53
0
def append_token(data: torch.Tensor, eos_token):
    start_token = torch.ones((data.size(0), 1), dtype=data.dtype) * eos_token
    end_token = torch.ones((data.size(0), 1), dtype=data.dtype) * eos_token

    return torch.cat([start_token, data, end_token], -1)
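# Illustrative usage (not part of the original snippet): the same id is prepended and
# appended to every row, growing the last dimension by two.
example_ids = torch.tensor([[5, 6, 7]])
example_wrapped = append_token(example_ids, eos_token=2)  # tensor([[2, 5, 6, 7, 2]])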
Beispiel #54
0
    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
        **params,
    ) -> torch.Tensor:
        offset = self.offset.view(*self.batch_shape, 1, 1)

        batch_shape = x1.shape[:-2]
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        if diag:
            base_diag = (x1 * x2).sum(dim=-1) + self.offset
            K11_diag = base_diag.pow(self.power)

            all_outers_diag = (x1 * x2).transpose(-2, -1).reshape(
                *batch_shape, -1)
            K22_base_diag = self.power * (self.power -
                                          1) * base_diag.pow(self.power - 2)
            K12_base_diag = self.power * base_diag.pow(self.power - 1)

            K22_diag = torch.add(
                all_outers_diag *
                K22_base_diag.repeat(*([1] * (K22_base_diag.dim() - 1)), d),
                K12_base_diag.repeat(*([1] * (K12_base_diag.dim() - 1)), d),
            )

            K_diag = torch.cat([K11_diag, K22_diag], dim=-1)
            # Apply perfect shuffle
            pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape(
                (n1 * (d + 1)))
            K_diag = K_diag[..., pi1]
            return K_diag

        else:
            base_inner_prod = torch.matmul(x1, x2.transpose(-2, -1)) + offset
            K11 = base_inner_prod.pow(self.power)

            K12_base = self.power * base_inner_prod.pow(self.power - 1)
            K12 = torch.zeros(*batch_shape,
                              n1,
                              n2 * d,
                              dtype=x1.dtype,
                              device=x1.device)

            ones_ = torch.ones(*batch_shape,
                               d,
                               1,
                               n2,
                               dtype=x1.dtype,
                               device=x1.device)
            K12_outer_prods = torch.matmul(
                x1.transpose(-2, -1).unsqueeze(-1), ones_)
            K12 = (K12_base.unsqueeze(-3) * K12_outer_prods).transpose(
                -3, -2).reshape(*batch_shape, n1, d * n2)

            ones_ = torch.ones(*batch_shape,
                               d,
                               n1,
                               1,
                               dtype=x1.dtype,
                               device=x1.device)
            K21_outer_prods = torch.matmul(ones_,
                                           x2.transpose(-2, -1).unsqueeze(-2))
            K21 = (K12_base.unsqueeze(-3) * K21_outer_prods).view(
                *batch_shape, d * n1, n2)

            K22_base = self.power * (self.power -
                                     1) * base_inner_prod.pow(self.power - 2)
            K22 = torch.zeros(*batch_shape,
                              n1 * d,
                              n2 * d,
                              dtype=x1.dtype,
                              device=x1.device)
            all_outers = x1.unsqueeze(-2).unsqueeze(-2).transpose(
                -2, -1).matmul(x2.unsqueeze(-3).unsqueeze(-2))
            all_outers = all_outers.transpose(-4, -2).transpose(-3, -1)
            K22 = K22_base.unsqueeze(-3).unsqueeze(
                -3) * all_outers  # d x d x n1 x n2

            # Can't avoid this for loop without unnecessary memory duplication, which is worse.
            for i in range(d):
                K22[..., i, i, :, :] = K22[..., i, i, :, :] + K12_base

            K22 = K22.transpose(-4, -3).transpose(-3, -2).reshape(
                *batch_shape, n1 * d, n2 * d)

            K = torch.cat(
                [torch.cat([K11, K12], dim=-1),
                 torch.cat([K21, K22], dim=-1)],
                dim=-2)

            # Apply perfect shuffle
            pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape(
                (n1 * (d + 1)))
            pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape(
                (n2 * (d + 1)))
            K = K[..., pi1, :][..., :, pi2]

            return K
Beispiel #55
0
 def forward(ctx, x: Tensor, inplace: bool = False):
     ctx.save_for_backward(x)
     x_sigmoid = torch.sigmoid(x)
     return x.mul_(x_sigmoid) if inplace else x.mul(x_sigmoid)
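# The forward pass above implements the Swish/SiLU activation, x * sigmoid(x); saving `x`
# via ctx lets a matching backward pass recompute the gradient. Illustrative check
# (not part of the original snippet): torch.nn.functional.silu(t) produces the same result
# for any tensor t in recent PyTorch versions.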
Beispiel #56
0
    def _gather_final_log_probs(
        self,
        generation_log_probs: torch.Tensor,
        copy_log_probs: torch.Tensor,
        state: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Combine copy probabilities with generation probabilities for matching tokens.

        # Parameters

        generation_log_probs : `torch.Tensor`
            Shape: `(group_size, target_vocab_size)`
        copy_log_probs : `torch.Tensor`
            Shape: `(group_size, source_sequence_length)`
        state : `Dict[str, torch.Tensor]`

        # Returns

        torch.Tensor
            Shape: `(group_size, target_vocab_size + source_sequence_length)`.
        """
        _, source_sequence_length = state["source_to_target"].size()
        source_token_ids = state["source_token_ids"]

        # shape: [(batch_size, *)]
        modified_log_probs_list: List[torch.Tensor] = []
        for i in range(source_sequence_length):
            # shape: (group_size,)
            copy_log_probs_slice = copy_log_probs[:, i]
            # `source_to_target` is a matrix of shape (group_size, source_sequence_length)
            # where element (i, j) is the vocab index of the target token that matches the jth
            # source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
            # We'll use this to add copy scores to corresponding generation scores.
            # shape: (group_size,)
            source_to_target_slice = state["source_to_target"][:, i]
            # The OOV index in the source_to_target_slice indicates that the source
            # token is not in the target vocab, so we don't want to add that copy score
            # to the OOV token.
            copy_log_probs_to_add_mask = source_to_target_slice != self._oov_index
            copy_log_probs_to_add = (
                copy_log_probs_slice +
                (copy_log_probs_to_add_mask +
                 util.tiny_value_of_dtype(copy_log_probs_slice.dtype)).log())
            # shape: (batch_size, 1)
            copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
            # shape: (batch_size, 1)
            selected_generation_log_probs = generation_log_probs.gather(
                1, source_to_target_slice.unsqueeze(-1))
            combined_scores = util.logsumexp(
                torch.cat(
                    (selected_generation_log_probs, copy_log_probs_to_add),
                    dim=1))
            generation_log_probs = generation_log_probs.scatter(
                -1, source_to_target_slice.unsqueeze(-1),
                combined_scores.unsqueeze(-1))
            # We have to combine copy scores for duplicate source tokens so that
            # we can find the overall most likely source token. So, if this is the first
            # occurrence of this particular source token, we add the log_probs from all other
            # occurrences, otherwise we zero it out since it was already accounted for.
            if i < (source_sequence_length - 1):
                # Sum copy scores from future occurrences of source token.
                # shape: (group_size, source_sequence_length - i)
                source_future_occurences = source_token_ids[:, (
                    i + 1):] == source_token_ids[:, i].unsqueeze(-1)
                # shape: (group_size, source_sequence_length - i)
                future_copy_log_probs = (
                    copy_log_probs[:, (i + 1):] +
                    (source_future_occurences +
                     util.tiny_value_of_dtype(copy_log_probs.dtype)).log())
                # shape: (group_size, 1 + source_sequence_length - i)
                combined = torch.cat((copy_log_probs_slice.unsqueeze(-1),
                                      future_copy_log_probs),
                                     dim=-1)
                # shape: (group_size,)
                copy_log_probs_slice = util.logsumexp(combined)
            if i > 0:
                # Remove copy log_probs that we have already accounted for.
                # shape: (group_size, i)
                source_previous_occurences = source_token_ids[:, 0:
                                                              i] == source_token_ids[:, i].unsqueeze(
                                                                  -1)
                # shape: (group_size,)
                duplicate_mask = source_previous_occurences.sum(dim=-1) == 0
                copy_log_probs_slice = (
                    copy_log_probs_slice +
                    (duplicate_mask + util.tiny_value_of_dtype(
                        copy_log_probs_slice.dtype)).log())

            # Finally, we zero-out copy scores that we added to the generation scores
            # above so that we don't double-count them.
            # shape: (group_size,)
            left_over_copy_log_probs = (
                copy_log_probs_slice +
                (~copy_log_probs_to_add_mask +
                 util.tiny_value_of_dtype(copy_log_probs_slice.dtype)).log())
            modified_log_probs_list.append(
                left_over_copy_log_probs.unsqueeze(-1))
        modified_log_probs_list.insert(0, generation_log_probs)

        # shape: (group_size, target_vocab_size + source_sequence_length)
        modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

        return modified_log_probs
Beispiel #57
0
def simplex(t: Tensor, axis=1) -> bool:
    _sum = t.sum(axis).type(torch.float32)
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)
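# Illustrative usage (not part of the original snippet): rows of a softmax output sum to
# one along the class axis, so they pass the simplex check.
example_probs = torch.softmax(torch.randn(4, 3), dim=1)
assert simplex(example_probs, axis=1)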
Beispiel #58
0
    def _get_ll_contrib(
        self,
        generation_scores: torch.Tensor,
        generation_scores_mask: torch.BoolTensor,
        copy_scores: torch.Tensor,
        target_tokens: torch.Tensor,
        target_to_source: torch.Tensor,
        source_mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the log-likelihood contribution from a single timestep.

        # Parameters

        generation_scores : `torch.Tensor`
            Shape: `(batch_size, target_vocab_size)`
        generation_scores_mask : `torch.BoolTensor`
            Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
        copy_scores : `torch.Tensor`
            Shape: `(batch_size, source_sequence_length)`
        target_tokens : `torch.Tensor`
            Shape: `(batch_size,)`
        target_to_source : `torch.Tensor`
            Shape: `(batch_size, source_sequence_length)`
        source_mask : `torch.BoolTensor`
            Shape: `(batch_size, source_sequence_length)`

        # Returns

        Tuple[torch.Tensor, torch.Tensor]
            Shape: `(batch_size,), (batch_size, source_sequence_length)`
        """
        _, target_size = generation_scores.size()

        # The point of this mask is to just mask out all source token scores
        # that just represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + source_sequence_length)
        mask = torch.cat((generation_scores_mask, source_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + source_sequence_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + source_sequence_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, source_sequence_length)
        copy_log_probs = (log_probs[:, target_size:] +
                          (target_to_source.to(log_probs.dtype) +
                           util.tiny_value_of_dtype(log_probs.dtype)).log())
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:],
                                                target_to_source)
        # This mask ensures that an item in the batch has non-zero generation probability
        # for this timestep only when the gold target token is not OOV or there are no
        # matching tokens in the source sentence.
        # shape: (batch_size, 1)
        gen_mask = (target_tokens !=
                    self._oov_index) | (target_to_source.sum(-1) == 0)
        log_gen_mask = (
            gen_mask +
            util.tiny_value_of_dtype(log_probs.dtype)).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(
            1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + source_sequence_length)
        combined_gen_and_copy = torch.cat(
            (generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)

        return step_log_likelihood, selective_weights
Beispiel #59
0
    def _get_entity_action_logits(
        self,
        state: WikiTablesDecoderState,
        actions_to_link: List[List[int]],
        attention_weights: torch.Tensor,
        linked_checklist_balance: torch.Tensor = None
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.FloatTensor]:
        """
        Returns scores for each action in ``actions_to_link`` that are derived from the linking
        scores between the question and the table entities, and the current attention on the
        question.  The intuition is that if we're paying attention to a particular word in the
        question, we should tend to select entity productions that we think that word refers to.
        We additionally return a mask representing which elements in the returned ``action_logits``
        tensor are just padding, and an embedded representation of each action that can be used as
        input to the next step of the encoder.  That embedded representation is derived from the
        type of the entity produced by the action.

        The ``actions_to_link`` are in terms of the `batch` action list passed to
        ``model.forward()``.  We need to convert these integers into indices into the linking score
        tensor, which has shape (batch_size, num_entities, num_question_tokens), look up the
        linking score for each entity, then aggregate the scores using the current question
        attention.

        Parameters
        ----------
        state : ``WikiTablesDecoderState``
            The current state.  We'll use this to get the linking scores.
        actions_to_link : ``List[List[int]]``
            A list of _batch_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.  This is expected to be output from
            :func:`_get_actions_to_consider`.
        attention_weights : ``torch.Tensor``
            The current attention weights over the question tokens.  Should have shape
            ``(group_size, num_question_tokens)``.
        linked_checklist_balance : ``torch.Tensor``, optional (default=None)
            If the parser is being trained to maximize coverage over an agenda, this is the balance
            vector corresponding to entity actions, containing 1s and 0s, with 1s showing the
            actions that are yet to be produced. Required only if the parser is being trained to
            maximize coverage.

        Returns
        -------
        action_logits : ``torch.FloatTensor``
            A score for each of the given actions.  Shape is ``(group_size, num_actions)``, where
            ``num_actions`` is the maximum number of considered actions for any group element.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        type_embeddings : ``torch.FloatTensor``
            A tensor of shape ``(group_size, num_actions, action_embedding_dim)``, with an embedded
            representation of the `type` of the entity corresponding to each action.
        """
        # First we map the actions to entity indices, using state.actions_to_entities, and find the
        # type of each entity using state.entity_types.
        action_entities: List[List[int]] = []
        entity_types: List[List[int]] = []
        for batch_index, action_list in zip(state.batch_indices,
                                            actions_to_link):
            action_entities.append([])
            entity_types.append([])
            for action_index in action_list:
                entity_index = state.actions_to_entities[(batch_index,
                                                          action_index)]
                action_entities[-1].append(entity_index)
                entity_types[-1].append(state.entity_types[entity_index])

        # Then we create a padded tensor suitable for use with
        # `state.flattened_linking_scores.index_select()`.
        num_actions = [len(action_list) for action_list in action_entities]
        max_num_actions = max(num_actions)
        padded_actions = [
            common_util.pad_sequence_to_length(action_list, max_num_actions)
            for action_list in action_entities
        ]
        padded_types = [
            common_util.pad_sequence_to_length(type_list, max_num_actions)
            for type_list in entity_types
        ]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions,
                                                  dtype=torch.long)
        type_tensor = state.score[0].new_tensor(padded_types, dtype=torch.long)

        # To get the type embedding tensor, we just use an embedding matrix on the list of entity
        # types.
        type_embeddings = self._entity_type_embedding(type_tensor)
        # `state.flattened_linking_scores` is shape (batch_size * num_entities, num_question_tokens).
        # We want to select from this using `action_tensor` to get a tensor of shape (group_size,
        # num_actions, num_question_tokens).  Unfortunately, the index_select functions in nn.util
        # don't do this operation.  So we'll do some reshapes and do the index_select ourselves.
        group_size = len(state.batch_indices)
        num_question_tokens = state.flattened_linking_scores.size(-1)
        flattened_actions = action_tensor.view(-1)
        # (group_size * num_actions, num_question_tokens)
        flattened_action_linking = state.flattened_linking_scores.index_select(
            0, flattened_actions)
        # (group_size, num_actions, num_question_tokens)
        action_linking = flattened_action_linking.view(group_size,
                                                       max_num_actions,
                                                       num_question_tokens)

        # Now we get action logits by weighting these entity x token scores by the attention over
        # the question tokens.  We can do this efficiently with torch.bmm.
        action_logits = action_linking.bmm(
            attention_weights.unsqueeze(-1)).squeeze(-1)
        if linked_checklist_balance is not None:
            # ``linked_checklist_balance`` is a binary tensor of size (group_size, num_actions) with
            # 1s indicating the linked actions that the agenda wants the decoder to produce, but
            # haven't been produced yet. We're simply doubling the logits of those actions here.
            action_logits_addition = action_logits * linked_checklist_balance
            action_logits = action_logits + self._linked_checklist_multiplier * action_logits_addition
        # Finally, we make a mask for our action logit tensor.
        sequence_lengths = action_linking.new_tensor(num_actions)
        action_mask = util.get_mask_from_sequence_lengths(
            sequence_lengths, max_num_actions)
        return action_logits, action_mask, type_embeddings
Beispiel #60
0
 def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
     return loss.backward()