Example #1
    def collate_fn(batch):
        ps, qs, p_lens, q_lens, labels, p_chars, q_chars, p_pos, q_pos = zip(
            *batch)

        # batch-wise clip
        p_lens = LongTensor(p_lens)
        q_lens = LongTensor(q_lens)
        labels = LongTensor(labels)

        p_max_len = p_lens.max().item()
        q_max_len = q_lens.max().item()
        char_len = p_chars[0].shape[-1]  # NOTE: p_chars is a tuple, so take char_len from its first element

        new_ps = torch.zeros(size=(len(batch), p_max_len)).long()
        new_qs = torch.zeros(size=(len(batch), q_max_len)).long()
        new_p_chars = torch.zeros(size=(len(batch), p_max_len,
                                        char_len)).long()
        new_q_chars = torch.zeros(size=(len(batch), q_max_len,
                                        char_len)).long()
        new_p_pos = torch.zeros(size=(len(batch), p_max_len)).long()
        new_q_pos = torch.zeros(size=(len(batch), q_max_len)).long()

        for i, (p, pc, pp, p_len) in enumerate(zip(ps, p_chars, p_pos,
                                                   p_lens)):
            new_ps[i, :p_len] = LongTensor(p[:p_len])
            new_p_chars[i, :p_len, :] = LongTensor(pc[:p_len, :])
            new_p_pos[i, :p_len] = LongTensor(pp[:p_len])
        for i, (q, qc, qp, q_len) in enumerate(zip(qs, q_chars, q_pos,
                                                   q_lens)):
            new_qs[i, :q_len] = LongTensor(q[:q_len])
            new_q_chars[i, :q_len, :] = LongTensor(qc[:q_len, :])
            new_q_pos[i, :q_len] = LongTensor(qp[:q_len])

        return new_ps, p_lens, new_qs, q_lens, labels, new_p_chars, new_q_chars, new_p_pos, new_q_pos
Example #2
    def forward(self,
                input_: Union[Tuple[torch.Tensor, ...], List[torch.Tensor],
                              torch.Tensor],
                lengths: torch.LongTensor,
                max_length: torch.LongTensor,
                target: Union[Tuple[torch.Tensor, ...], List[torch.Tensor],
                              torch.Tensor] = None,
                return_kwargs: bool = True,
                **kwargs):

        org_lengths = lengths
        batch_size = len(lengths)

        num_valid_chunks = self._length_to_num_chunks(lengths)
        lengths_win = self._window_lengths(lengths)
        max_length_win = torch.tensor(self.window_size,
                                      dtype=lengths.dtype,
                                      device=lengths.device)

        if type(input_) in [tuple, list]:
            input_win = [self._window(i) for i in input_]
        else:
            input_win = self._window(input_)

        if target is not None:
            if type(target) in [tuple, list]:
                target_win = [self._window(t) for t in target]
            else:
                target_win = self._window(target)
            kwargs["target"] = target_win

        output, kwargs = self.model(input_win, lengths_win, max_length_win,
                                    **kwargs)
        lengths_win = kwargs['lengths']
        max_length_win = kwargs['max_length']

        if self.output_merge_type == ModelConfig.MERGE_TYPE_CAT:
            assert lengths.max() == max_length, "This module does not support "\
                "parallel training with MERGE_TYPE_CAT."
            # NOTE: The problem is that max_length_win will be longer than
            #       lengths_win on GPUs with the shorter sequences. In that case
            #       the last element of max_length_win is never passed to the
            #       model, thus we do not know how it is changed in length.

        lengths = self._merge_lengths(lengths_win, batch_size)

        # At this point the merge operation either reduces all chunks to a fixed
        # number of frames (e.g. mean, mul, etc.) or lengths.max() and
        # max_length are equal, so we can use it as the max_length.
        max_length = lengths.max()

        output = self._merge_outputs(output, num_valid_chunks)

        kwargs['lengths'] = lengths
        kwargs['max_length'] = max_length
        if return_kwargs:
            return output, kwargs
        else:
            return output
Example #3
    def _prepare_seq_for_padding(self, sequence):

        # padding of words to give later to FOFE Encoding layer

        seq_lengths = LongTensor(list(map(len, sequence)))
        seq_tensor = Variable(torch.zeros(
            (len(sequence), seq_lengths.max()))).long()
        for idx, (seq, seqlen) in enumerate(zip(sequence, seq_lengths)):
            # for FOFE encoding we need padding on both sentence and word level
            # on word level pad in front of word, on sentence level at the end of sentence
            seq_tensor[idx, seq_lengths.max()-seqlen:] = LongTensor(seq)

        return seq_tensor
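A self-contained sketch of the left-padding behaviour above on an assumed toy batch (the input values are illustrative, not from the original source):

import torch
from torch import LongTensor

sequence = [[5, 3], [7, 2, 9, 1]]                  # two word-index sequences
seq_lengths = LongTensor(list(map(len, sequence)))
max_len = int(seq_lengths.max())
seq_tensor = torch.zeros((len(sequence), max_len), dtype=torch.long)
for idx, (seq, seqlen) in enumerate(zip(sequence, seq_lengths)):
    # pad in front of the shorter word, as in _prepare_seq_for_padding
    seq_tensor[idx, max_len - seqlen:] = LongTensor(seq)
# seq_tensor -> tensor([[0, 0, 5, 3],
#                       [7, 2, 9, 1]])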
Example #4
    def _bert_encode(self,
                     data: torch.Tensor,
                     lengths: LongTensor,
                     bertmodel,
                     token_type_ids: Optional[torch.Tensor] = None,
                     attention_mask: Optional[torch.Tensor] = None,
                     position_ids: Optional[torch.Tensor] = None,
                     head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Uses a BERT model to encode a batch of token id sequences.
        :param data: A LongTensor of shape `(batch_size, sequence_length)` containing wordpiece ids.
        :param lengths: A LongTensor of shape `(batch_size)` containing the lengths of the sequences, used for masking.
        :return: A FloatTensor of shape `(batch_size, output_size)` containing the encoding for each sequence
        in the batch.
        """
        # print(data.shape)
        # Create mask for padding
        max_len = lengths.max().item()
        attention_mask = torch.zeros(len(data), max_len, dtype=torch.float)
        for i in range(len(data)):
            attention_mask[i, :lengths[i]] = 1

        # if attention_mask is None and self.padding_idx is not None:
        #     attention_mask = (data != self.padding_idx).float()

        data = data.to(self.device)
        attention_mask = attention_mask.to(self.device)

        outputs = bertmodel(data, attention_mask=attention_mask)
        output = outputs[1]
        return output
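As a hedged aside (not part of the original code), the mask-building loop above can be expressed as a single broadcast comparison:

import torch

lengths = torch.tensor([3, 1])
max_len = lengths.max().item()
attention_mask = (torch.arange(max_len)[None, :] < lengths[:, None]).float()
# tensor([[1., 1., 1.],
#         [1., 0., 0.]])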
Example #5
def _check_points_to_volumes_inputs(
    points_3d: torch.Tensor,
    points_features: torch.Tensor,
    volume_densities: torch.Tensor,
    volume_features: torch.Tensor,
    grid_sizes: torch.LongTensor,
    mask: Optional[torch.Tensor] = None,
):

    # pyre-fixme[16]: `Tuple` has no attribute `values`.
    max_grid_size = grid_sizes.max(dim=0).values
    if torch.prod(max_grid_size) > volume_densities.shape[1]:
        raise ValueError(
            "One of the grid sizes corresponds to a larger number"
            + " of elements than the number of elements in volume_densities."
        )

    _, n_voxels, density_dim = volume_densities.shape

    if density_dim != 1:
        raise ValueError("Only one-dimensional densities are allowed.")

    ba, n_points, feature_dim = points_features.shape

    if volume_features.shape[1] != feature_dim:
        raise ValueError(
            "volume_features have a different number of channels"
            + " than points_features."
        )

    if volume_features.shape[2] != n_voxels:
        raise ValueError(
            "volume_features have a different number of elements"
            + " than volume_densities."
        )
Example #6
    def get_attention_mask(self, encoder_lengths: torch.LongTensor,
                           decoder_length: int):
        """
        Returns causal mask to apply for self-attention layer.

        Args:
            encoder_lengths: lengths of the encoder sequences, used to mask out padded positions
            decoder_length: length of the decoder sequence, used to build the causal mask
        """
        # indices to which is attended
        attend_step = torch.arange(decoder_length, device=self.device)
        # indices for which is predicted
        predict_step = torch.arange(0, decoder_length,
                                    device=self.device)[:, None]
        # do not attend to steps to self or after prediction
        # todo: there is potential value in attending to future forecasts if they are made with knowledge currently
        #   available
        #   one possibility is here to use a second attention layer for future attention (assuming different effects
        #   matter in the future than the past)
        #   or alternatively using the same layer but allowing forward attention - i.e. only masking out non-available
        #   data and self
        decoder_mask = attend_step >= predict_step
        # do not attend to steps where data is padded
        encoder_mask = create_mask(encoder_lengths.max(), encoder_lengths)
        # combine masks along attended time - first encoder and then decoder
        mask = torch.cat(
            (
                encoder_mask.unsqueeze(1).expand(-1, decoder_length, -1),
                decoder_mask.unsqueeze(0).expand(encoder_lengths.size(0), -1,
                                                 -1),
            ),
            dim=2,
        )
        return mask
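A small hedged illustration of the causal half of the mask above, with decoder_length = 3 (computed on CPU for readability):

import torch

attend_step = torch.arange(3)
predict_step = torch.arange(3)[:, None]
decoder_mask = attend_step >= predict_step
# tensor([[ True,  True,  True],
#         [False,  True,  True],
#         [False, False,  True]])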
Example #7
def create_mask(lengths: LongTensor, cuda: bool = False) -> ByteTensor:
    """
    Creates a mask from a tensor of sequence lengths to mask out padding with 1s for content and 0s for padding.

    Example:
        >>> lengths = LongTensor([3, 4, 2])
        >>> create_mask(lengths)
        tensor([[1, 1, 1],
                [1, 1, 1],
                [1, 1, 0],
                [0, 1, 0]], dtype=torch.uint8)

    :param lengths: A LongTensor of shape `(batch_size)` with the length of each sequence in the batch.
    :param cuda: A boolean indicating whether to move the mask to GPU.
    :return: A ByteTensor of shape `(sequence_length, batch_size)` with 1s for content and 0s for padding.
    """
    # Get sizes
    seq_len, batch_size = lengths.max(), lengths.size(0)

    # Create length and index masks
    length_mask = lengths.unsqueeze(0).repeat(seq_len, 1)  # (seq_len, batch_size)

    index_mask = torch.arange(seq_len, dtype=torch.long).unsqueeze(1).repeat(1, batch_size)  # (seq_len, batch_size)

    # Create mask
    mask = (index_mask < length_mask)

    # Move to GPU
    if cuda:
        mask = mask.cuda()

    return mask
Example #8
def scatter_sort(
    src: Tensor,
    index: LongTensor,
    descending=False,
    dim_size=None,
    out: Optional[Tuple[Tensor, LongTensor]] = None,
) -> Tuple[Tensor, LongTensor]:
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")

    if dim_size is None:
        dim_size = index.max() + 1

    if out is None:
        result_values = torch.empty_like(src)
        result_indexes = index.new_empty(src.shape)
    else:
        result_values, result_indexes = out

    sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for size in sizes:
        end = start + size
        values, indexes = torch.sort(src[start:end], dim=0, descending=descending)
        result_values[start:end] = values
        result_indexes[start:end] = indexes + start
        start = end

    return result_values, result_indexes
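A hedged usage sketch with toy tensors, assuming index describes contiguous ascending groups as the sequential slicing above requires:

import torch
from torch import LongTensor

src = torch.tensor([3.0, 1.0, 2.0, 5.0, 4.0])
index = LongTensor([0, 0, 0, 1, 1])
values, perm = scatter_sort(src, index)
# values -> tensor([1., 2., 3., 4., 5.])   sorted within each group
# perm   -> tensor([1, 2, 0, 4, 3])        positions into the original src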
Example #9
def torch_is_in_1d(
    query_tensor: torch.LongTensor,
    test_tensor: Union[Collection[int], torch.LongTensor],
    max_id: Optional[int] = None,
    invert: bool = False,
) -> torch.BoolTensor:
    """
    Return a boolean mask with ``Q[i]`` in T.

    The method guarantees memory complexity of ``max(size(Q), size(T))`` and is thus, memory-wise, superior to naive
    broadcasting.

    :param query_tensor: shape: S
        The query Q.
    :param test_tensor:
        The test set T.
    :param max_id:
        A maximum ID. If not given, will be inferred.
    :param invert:
        Whether to invert the result.

    :return: shape: S
        A boolean mask.
    """
    # normalize input
    if not isinstance(test_tensor, torch.Tensor):
        test_tensor = torch.as_tensor(data=list(test_tensor), dtype=torch.long)
    if max_id is None:
        max_id = max(query_tensor.max(), test_tensor.max()) + 1
    mask = torch.zeros(max_id, dtype=torch.bool)
    mask[test_tensor] = True
    if invert:
        mask = ~mask
    return mask[query_tensor.view(-1)].view(*query_tensor.shape)
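A hedged usage sketch with toy inputs (not from the original source):

import torch

query = torch.tensor([0, 3, 5, 7])
torch_is_in_1d(query, {3, 7, 9})
# tensor([False,  True, False,  True])
torch_is_in_1d(query, {3, 7, 9}, invert=True)
# tensor([ True, False,  True, False])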
Example #10
 def build_position_id(self, lengths: t.LongTensor):
     batch_size = lengths.size(0)
     max_length = lengths.max()
     device = lengths.device
     position_id = t.zeros(batch_size,
                           max_length,
                           device=device,
                           dtype=t.long)
     for index, value in enumerate(lengths):
         position_id[index][:value] = self.position[:value]
     return position_id
Example #11
    def collate_fn(batch):
        ps, qs, p_lens, q_lens, labels = zip(*batch)

        # batch-wise clip
        p_lens = LongTensor(p_lens)
        q_lens = LongTensor(q_lens)
        labels = LongTensor(labels)

        p_max_len = p_lens.max().item()
        q_max_len = q_lens.max().item()

        new_ps = torch.zeros(size=(len(batch), p_max_len)).long()
        new_qs = torch.zeros(size=(len(batch), q_max_len)).long()

        for i, (p, p_len) in enumerate(zip(ps, p_lens)):
            new_ps[i, :p_len] = LongTensor(p[:p_len])
        for i, (q, q_len) in enumerate(zip(qs, q_lens)):
            new_qs[i, :q_len] = LongTensor(q[:q_len])

        return new_ps, p_lens, new_qs, q_lens, labels
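A hedged, self-contained demo of the batch-wise clipping above on toy data (it assumes collate_fn is reachable at module level):

import torch
from torch import LongTensor

batch = [
    ([4, 2, 9], [7, 1],       3, 2, 0),
    ([5, 8],    [6, 6, 6, 6], 2, 4, 1),
]
ps, p_lens, qs, q_lens, labels = collate_fn(batch)
# ps -> tensor([[4, 2, 9],
#               [5, 8, 0]])   qs is padded to length 4, labels -> tensor([0, 1])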
Example #12
def sequence_mask(lengths: torch.LongTensor, max_len: Optional[int] = None):
    """Create a boolean mask from sequence lengths.

    Arguments:
        lengths: lengths with shape (bs,)
        max_len: max sequence length; if None it will be set to lengths.max()
    """
    if max_len is None:
        max_len = lengths.max()
    # Broadcast comparison: mask[i, j] = (j < lengths[i])
    mask = torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]
    return mask
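A hedged usage sketch with toy lengths (not from the original source):

import torch

sequence_mask(torch.tensor([3, 1, 4]))
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False],
#         [ True,  True,  True,  True]])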
Example #13
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of gold labels.  These can either be integer
            indexes or one hot arrays of labels, so of shape ``(batch_size, num_tokens)`` or of
            shape ``(batch_size, num_tokens, num_tags)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.text_field_embedder(tokens)
        batch_size = embedded_text_input.size()[0]
        encoded_text, _ = self.stacked_encoders(embedded_text_input)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, -1, self.num_classes])

        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }

        if tags is not None:
            # Negative log likelihood criterion takes integer labels, not one hot.
            if tags.dim() == 3:
                _, tags = tags.max(-1)
            loss = self.sequence_loss(reshaped_log_probs, tags.view(-1))
            output_dict["loss"] = loss

        return output_dict
Example #14
    def _embed(self,
               index_batch: List[List[int]]) -> Tuple[FloatTensor, LongTensor]:
        """
        Embeds a batch of indices and returns a tensors with the embeddings and sequence lengths.

        :param batch: A list of lists of ints. Each element of the outer list is a single sequence in the batch.
        Each element of the inner list is a word index (as specified by `self.word_to_index`).
        :return: A tuple of (FloatTensor, LongTensor) with shapes `(sequence_length, batch_size, embedding_size)` and
        `(batch_size)` respectively. The FloatTensor contains the embedded batch, with zero padding added as necessary.
        The LongTensor contains the length of each sequence in the batch.
        """
        # Get lengths
        lengths = LongTensor([len(sequence) for sequence in index_batch])
        # TODO: switch to pytorch version 0.4.1 and add .item()
        seq_len = (lengths.max() if torch.__version__[:3] == '0.3'
                   else lengths.max().item())

        # Add padding
        index_batch = [
            sequence + [self.padding_idx] * (seq_len - len(sequence))
            for sequence in index_batch
        ]

        # Cast to tensor
        index_batch = LongTensor(index_batch)  # (batch_size, sequence_length)

        # Optionally move to GPU
        if self.is_cuda:
            index_batch = index_batch.cuda()

        # Get embeddings
        batch = self.embedding_layer(
            index_batch)  # (batch_size, sequence_length, embedding_size)

        # Transpose to (sequence_length, batch_size, embedding_size)
        batch = batch.transpose(0, 1)

        return batch, lengths
Example #15
def scatter_topk(
    src: Tensor, index: LongTensor, k: int, num_chunks=None, fill_value=None
) -> Tuple[Tensor, LongTensor, LongTensor]:
    """

    Args:
        src:
        index: must be sorted in ascending order
        k:
        num_chunks:
        fill_value:

    Returns: A 1D tensor of shape [num_chunks * k]

    """
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")

    if num_chunks is None:
        num_chunks = index.max().item() + 1

    if fill_value is None:
        fill_value = float("NaN")

    result_values = src.new_full((num_chunks * k,), fill_value=fill_value)
    result_indexes_whole = index.new_full((num_chunks * k,), fill_value=-1)
    result_indexes_within_chunk = index.new_full((num_chunks * k,), fill_value=-1)

    chunk_sizes = (
        index.new_zeros(num_chunks)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        chunk = src[start : start + chunk_size]
        values, indexes = torch.topk(chunk, k=min(k, chunk_size), dim=0)

        result_values[chunk_idx * k : chunk_idx * k + len(values)] = values
        result_indexes_within_chunk[
            chunk_idx * k : chunk_idx * k + len(indexes)
        ] = indexes
        result_indexes_whole[chunk_idx * k : chunk_idx * k + len(indexes)] = (
            indexes + start
        )

        start += chunk_size

    return result_values, result_indexes_whole, result_indexes_within_chunk
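A hedged usage sketch with toy tensors (index must be sorted in ascending order, as the docstring notes):

import torch
from torch import LongTensor

src = torch.tensor([0.1, 0.9, 0.5, 0.3, 0.7])
index = LongTensor([0, 0, 0, 1, 1])
values, idx_whole, idx_within = scatter_topk(src, index, k=2)
# values     -> tensor([0.9000, 0.5000, 0.7000, 0.3000])
# idx_whole  -> tensor([1, 2, 4, 3])
# idx_within -> tensor([1, 2, 1, 0])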
Example #16
    def forward(self, inp: torch.FloatTensor, tgt: torch.LongTensor):
        if inp.size(0) != tgt.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')
        num_elms = 0
        entry_size = tgt.size(0)
        output = inp.new_zeros(entry_size)  # log probabilities
        gather_inds = tgt.new_empty(entry_size)  # tgt indices in head

        for i in range(self.n_clusters + 1):
            target_mask, rel_inds = \
                get_cluster_members(i, tgt, self.cutoffs, self.ent_slices)
            # members of the current cluster
            members = target_mask.nonzero().squeeze()
            if members.numel() == 0:
                continue
            if i == 0:  # Head cluster
                # Head cluster also needs to compute relative indices
                gather_inds.index_copy_(0, members, rel_inds[target_mask])
            else:  # Tail clusters including entity clusters
                cluster_index = self.cutoffs[0] + i - 1
                gather_inds.index_fill_(0, members, cluster_index)

                # Subset of input which elements should be in this cluster
                input_subset = inp.index_select(0, members)
                # Forward
                cluster_output = self.tail[i - 1](input_subset)
                cluster_logprob = F.log_softmax(cluster_output, dim=1)
                relative_target = rel_inds[target_mask]
                local_logprob = \
                    cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, members, local_logprob.squeeze(1))

            num_elms += members.numel()

        if num_elms != entry_size:
            logger.error('used_rows ({}) and batch_size ({}) does not match'
                         ''.format(num_elms, entry_size))
            raise RuntimeError("Target values should be in [0, {}], "
                               "but values in range [{}, {}] "
                               "were found. ".format(self.n_classes - 1,
                                                     tgt.min().item(),
                                                     tgt.max().item()))

        head_output = self.head(inp)
        head_logprob = F.log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()

        # return neglog
        return -output
Example #17
def load_batch_input_to_memory(batch_input, has_targets=True):
    if has_targets:
        batch_input = [seq[0] for seq in batch_input]
    else:
        batch_input = batch_input

    # Get the length of each seq in your batch
    tensor_lens = LongTensor(
        list(map(lambda x: max((len(x[0]), len(x[1]))), batch_input)))

    # Zero-padded float tensor of size (B, 2, T, 50)
    tensor_seqs = zeros((len(batch_input), 2, tensor_lens.max(), 50)).float()
    for idx, seq in enumerate(batch_input):
        tensor_seqs[idx, :2, :tensor_lens[idx]] = FloatTensor(np.array(seq))

    return tensor_seqs, tensor_lens
Example #18
    def encode(self,
               data: torch.Tensor,
               lengths: LongTensor,
               token_type_ids: Optional[torch.Tensor] = None,
               attention_mask: Optional[torch.Tensor] = None,
               position_ids: Optional[torch.Tensor] = None,
               head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Uses a BERT model to encode a batch of token id sequences.
        :param data: A LongTensor of shape `(batch_size, sequence_length)` containing wordpiece ids.
        :param lengths: A LongTensor of shape `(batch_size)` containing the lengths of the sequences, used for masking.
        :return: A FloatTensor of shape `(batch_size, output_size)` containing the encoding for each sequence
        in the batch.
        """
        # print(data.shape)
        # Create mask for padding
        max_len = lengths.max().item()
        attention_mask = torch.zeros(len(data), max_len, dtype=torch.float)
        for i in range(len(data)):
            attention_mask[i, :lengths[i]] = 1

        # if attention_mask is None and self.padding_idx is not None:
        #     attention_mask = (data != self.padding_idx).float()

        data = data.to(self.device)
        attention_mask = attention_mask.to(self.device)

        outputs = self.embedder(data, attention_mask=attention_mask)
        if self.pooling == 'first':
            output = outputs[1]
        elif self.pooling == 'average':
            output = outputs[0]
            output = output * attention_mask.unsqueeze(-1)
            output = output.mean(dim=1)

        if self.config['bert_freeze_all']:
            # print('freezing all bert weight')
            output = output.detach()
        # output = self.outlayer(output)
        return output
Example #19
    def _window_lengths(self, lengths: torch.LongTensor) -> torch.LongTensor:
        lengths = lengths.clone()
        if lengths.ndim == 0:
            lengths = lengths.unsqueeze(0)

        lengths_win = []
        while True:
            sub_lengths = torch.clamp(lengths, max=self.window_size)
            lengths = torch.clamp(lengths - self.step, min=0)
            if sub_lengths.max() <= self.window_size - self.step:
                # Missing frames were only in overlap so part of last chunk.
                break

            lengths_win.append(sub_lengths)

            if not self.zero_padding and lengths.max() < self.window_size:
                break
        if len(lengths_win) == 0:
            # Edge case: every length falls inside the first window's overlap, so nothing was appended.
            return sub_lengths
        else:
            return torch.stack(lengths_win, dim=1).view(-1)
Example #20
def merge_word_piece(output: torch.Tensor, word_pieces: List[Dict[int, int]],
                     lengths: torch.LongTensor) -> torch.Tensor:
    """
    piece: [ {start: l_span, ...}, {start: l_span, ...}, ...]
    lengths: [ l_1, l_2, ...]
    """
    representations, pad_len = list(), lengths.max().item()
    for i, pieces in enumerate(word_pieces):
        j, rep, out, origin_len = 0, list(), output[i], lengths[i].item()
        while len(rep) < origin_len:
            if j in pieces:
                rep_j = out[j:j + pieces[j]].mean(dim=0)
                j += pieces[j]
            else:
                rep_j = out[j]
                j += 1
            rep.append(rep_j)
        while len(rep) < pad_len:
            rep.append(torch.zeros_like(rep[0]))
        representations.append(torch.stack(rep))
    representations = torch.stack(representations)
    return representations
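A hedged toy example (assumed inputs) in which the first word was split into two word pieces and the second was not:

import torch

output = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # (batch=1, 3 pieces, dim=2)
word_pieces = [{0: 2}]            # pieces 0 and 1 belong to word 0
lengths = torch.tensor([2])       # two words after merging
merge_word_piece(output, word_pieces, lengths)
# tensor([[[1., 2.],              # mean of pieces 0 and 1
#          [4., 5.]]])            # piece 2 copied through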
Example #21
def softmax(
    src: torch.Tensor,
    index: torch.LongTensor,
    num_nodes: Union[None, int, torch.Tensor] = None,
    dim: int = 0,
) -> torch.Tensor:
    r"""
    Compute a sparsely evaluated softmax.

    Given a value tensor :attr:`src`, this function first groups the values
    along the given dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    :param src:
        The source tensor.
    :param index:
        The indices of elements for applying the softmax.
    :param num_nodes:
        The number of nodes, i.e., :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
    :param dim:
        The dimension along which to compute the softmax.

    :return:
        The softmax-ed tensor.
    """
    if torch_scatter is None:
        raise ImportError(
            "torch-scatter is not installed, attention aggregation won't work. "
            "Install it here: https://github.com/rusty1s/pytorch_scatter", )
    num_nodes = num_nodes or index.max() + 1
    out = src.transpose(dim, 0)
    out = out - torch_scatter.scatter_max(
        out, index, dim=0, dim_size=num_nodes)[0][index]
    out = out.exp()
    out = out / torch_scatter.scatter_add(
        out, index, dim=0, dim_size=num_nodes)[index].clamp_min(1.0e-16)
    return out.transpose(0, dim)
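A hedged usage sketch with toy values (torch-scatter must be installed for the call to work):

import torch

src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])
softmax(src, index)
# tensor([0.2689, 0.7311, 0.2689, 0.7311])   softmax within each index group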
Example #22
##--------------------##
embed = Embedding(len(vocab), 4)  # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5,
            batch_first=True)  # input_dim = 4, hidden_dim = 5

## Step 4: Pad instances with 0s till max length sequence ##
##--------------------------------------------------------##

# get the length of each seq in your batch
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
# seq_lengths => [ 8, 4,  6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

seq_tensor = Variable(torch.zeros(
    (len(vectorized_seqs), seq_lengths.max()))).long()
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str
#                [12  5  8 14  0  0  0  0]          # tiny
#                [ 7  3  2  5 13  7  0  0]]         # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##

seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
Example #23
def paste(background: Tensor, patch: Tensor, x: LongTensor, y: LongTensor, mask: Optional[Tensor] = None):
    """
    Pastes the given patch into the background image tensor at the specified location.
    Optionally a mask of the same size as the patch can be passed in to blend the
    pasted contents with the background.

    :param background: A batch of image tensors of shape (B, C, H, W) that represent the background
    :param patch: A batch of image tensors of shape (B, C, h, w) which values get pasted into the background
    :param x: The horizontal integer coordinates relative to the top left corner of the background image.
        This tensor must be a one-dimensional tensor of shape (B, ).
    :param y: The vertical integer coordinates relative to the top left corner of the background image.
        This tensor must be a one-dimensional tensor of shape (B, ).
    :param mask: A mask of the same size as the patch that is used to blend foreground and background values.
        It is optional and defaults to ones (all is foreground).
    :return: The composite tensor of background and foreground values of shape (B, C, H, W).

    Note:
        1.  The X- and Y-coordinates can exceed the range of the background image (negative and positive).
            The background will be dynamically padded and cropped again after pasting such that the
            contents can go over the borders of the background image.
        2.  Currently it only supports integer locations.
        3.  All tensors must be on the same device.
    """
    # background: (B, C, H, W)
    # patch, mask: (B, C, h, w)
    # x, y: (B, )
    b, c, H, W = background.shape
    _, _, h, w = patch.shape
    mask = torch.ones_like(patch) if mask is None else mask
    device = background.device
    assert b == patch.size(0) == mask.size(0)
    assert b == x.size(0) == y.size(0)
    assert c == patch.size(1) == mask.size(1)
    assert h == mask.size(-2)
    assert w == mask.size(-1)
    assert 1 == x.ndimension() == y.ndimension()
    assert device == patch.device == x.device == y.device == mask.device
    x = x.long()
    y = y.long()

    # dynamically pad background for patches that go over borders
    left = min(x.min().abs().item(), 0)
    top = min(y.min().abs().item(), 0)
    right = max(x.max().item() + w - W, 0)
    bottom = max(y.max().item() + h - H, 0)
    background = nn.functional.pad(background, pad=[left, right, top, bottom])

    # generate indices
    gridb, gridc, gridy, gridx = torch.meshgrid(
        torch.arange(b, device=device),
        torch.arange(c, device=device),
        torch.arange(h, device=device),
        torch.arange(w, device=device)
    )
    x = x.view(b, 1, 1, 1).repeat(1, c, h, w)
    y = y.view(b, 1, 1, 1).repeat(1, c, h, w)
    x = x + gridx + left
    y = y + gridy + top

    # we need to ignore negative indices, or pasted content will be rolled to the other side
    mask = mask * (x >= 0) * (y >= 0)
    # paste
    one = torch.tensor(1, dtype=mask.dtype)
    background[(gridb, gridc, y, x)] = mask * patch + (one - mask) * background[(gridb, gridc, y, x)]
    # crop away the padded regions
    background = background[..., top:(top + H), left:(left + W)]
    return background
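A hedged toy usage of paste (assumed shapes, everything on CPU):

import torch

bg = torch.zeros(1, 1, 4, 4)
patch = torch.ones(1, 1, 2, 2)
out = paste(bg, patch, x=torch.tensor([1]), y=torch.tensor([2]))
# out[0, 0] now holds ones in rows 2-3, columns 1-2; everything else stays zero.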
Example #24
def scatter_topk_2d_flat(
    src: Tensor, index: LongTensor, k: int, dim_size=None, fill_value=None
) -> Tuple[Tensor, Tuple[LongTensor, LongTensor], Tuple[LongTensor, LongTensor]]:
    """Finds the top k values in a 2D array partitioned along the dimension 0.

    ::

        +-----------------------+
        |          X            |
        |  X                    |
        |              X        |
        |     X                 |
        +-----------------------+
        |                       |
        |                 Y     |
        |       Y               |              +-------+
        |                       |              |X X X X|
        |                       |    top 4     +-------+
        |                       |  -------->   |X X X X|
        |                       |              +-------+
        |             Y         |              |Z Z Z Z|
        |                       |              +-------+
        |   Y                   |
        |                       |
        +-----------------------+
        |                       |
        |     Z       Z         |
        |                       |
        |        Z        Z     |
        |                       |
        +-----------------------+


    Args:
        src:
        index:
        k:
        dim_size:
        fill_value:

    Returns:

    """
    if src.ndimension() != 2:
        raise ValueError("Only implemented for 2D tensors")

    if dim_size is None:
        dim_size = index.max().item() + 1

    if fill_value is None:
        fill_value = float("NaN")

    ncols = src.shape[1]

    result_values = src.new_full((dim_size, k), fill_value=fill_value)
    result_indexes_whole_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_whole_1 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_1 = index.new_full((dim_size, k), fill_value=-1)

    chunk_sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start_src = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        flat_chunk = src[start_src : start_src + chunk_size, :].flatten()
        flat_values, flat_indexes = torch.topk(
            flat_chunk, k=min(k, chunk_size * ncols), dim=0
        )
        result_values[chunk_idx, : len(flat_values)] = flat_values

        indexes_0 = flat_indexes // ncols
        indexes_1 = flat_indexes % ncols
        result_indexes_within_chunk_0[chunk_idx, : len(flat_indexes)] = indexes_0
        result_indexes_within_chunk_1[chunk_idx, : len(flat_indexes)] = indexes_1

        result_indexes_whole_0[chunk_idx, : len(flat_indexes)] = indexes_0 + start_src
        result_indexes_whole_1[chunk_idx, : len(flat_indexes)] = indexes_1

        start_src += chunk_size

    return (
        result_values,
        (result_indexes_whole_0, result_indexes_whole_1),
        (result_indexes_within_chunk_0, result_indexes_within_chunk_1),
    )
Example #25
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            verb_indicator: torch.Tensor,
            sentence_end: torch.LongTensor,
            metadata: List[Any],
            tags: torch.LongTensor = None,
            offsets: torch.LongTensor = None):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """

        if isinstance(self.bert_model,
                      PretrainedTransformerMismatchedEmbedder):
            encoder_inputs = tokens["tokens"]
            if self.bert_config.type_vocab_size > 1:
                encoder_inputs["type_ids"] = verb_indicator
            encoded_text = self.bert_model(**encoder_inputs)
            batch_size = encoded_text.shape[0]
            if self.bert_config.type_vocab_size == 1:
                verb_embeddings = encoded_text[
                    torch.arange(batch_size).to(encoded_text.device),
                    verb_indicator.argmax(1), :]
                verb_embeddings = torch.where(
                    (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                        1, verb_embeddings.shape[-1]), verb_embeddings,
                    torch.zeros_like(verb_embeddings))
                encoded_text = torch.cat(
                    (encoded_text, verb_embeddings.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=2)
            mask = tokens["tokens"]["mask"]
            index = mask.sum(1).argmax().item()
            # print(mask.shape, encoded_text.shape, tokens["tokens"]["token_ids"].shape, tags.shape, max([len(x['words']) for x in metadata]), mask.sum(1)[index].item())
            # print(tokens["tokens"]["token_ids"][index,:])
        else:
            mask = get_text_field_mask(tokens)
            bert_embeddings, _ = self.bert_model(
                input_ids=util.get_token_ids_from_text_field_tensors(tokens),
                # token_type_ids=verb_indicator,
                attention_mask=mask,
            )

            batch_size, _ = mask.size()
            embedded_text_input = self.embedding_dropout(bert_embeddings)
            # Restrict to sentence part
            sentence_mask = (torch.arange(mask.shape[1]).unsqueeze(0).repeat(
                batch_size, 1).to(mask.device) <
                             sentence_end.unsqueeze(1).repeat(
                                 1, mask.shape[1])).long()
            cutoff = sentence_end.max().item()
            if self._encoder is None:
                encoded_text = embedded_text_input
                mask = sentence_mask[:, :cutoff].contiguous()
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
            else:
                predicate_embeddings = self.predicate_embedding(verb_indicator)
                encoder_inputs = torch.cat(
                    (embedded_text_input, predicate_embeddings), dim=-1)
                encoded_text = self._encoder(encoder_inputs,
                                             mask=sentence_mask.bool())
                # print(verb_indicator)
                predicate_index = (verb_indicator * torch.arange(
                    start=verb_indicator.shape[-1] - 1, end=-1,
                    step=-1).to(mask.device).unsqueeze(0).repeat(
                        batch_size, 1)).argmax(1)
                # print(predicate_index)
                predicate_hidden = encoded_text[
                    torch.arange(batch_size).to(mask.device), predicate_index]
                predicate_exists, _ = verb_indicator.max(1)
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
                mask = sentence_mask[:, :cutoff].contiguous()
                predicate_exists = predicate_exists.unsqueeze(1).repeat(
                    1, encoded_text.shape[-1])
                predicate_hidden = torch.where(
                    predicate_exists > 0, predicate_hidden,
                    torch.zeros_like(predicate_hidden))
                encoded_text = torch.cat(
                    (encoded_text, predicate_hidden.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=-1)

        sequence_length = encoded_text.shape[1]
        logits = self.tag_projection_layer(encoded_text)
        # print(mask, logits)
        if self._lp and sequence_length <= 100:
            eps = 1e-4
            Q = eps * torch.eye(
                sequence_length * self.num_classes,
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            p = logits.view(batch_size, -1)
            G = -1 * torch.eye(
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            h = torch.zeros_like(p)
            A = torch.arange(sequence_length *
                             self.num_classes).unsqueeze(0).repeat(
                                 sequence_length, 1)
            A2 = torch.arange(sequence_length).unsqueeze(1).repeat(
                1, sequence_length * self.num_classes) * self.num_classes
            A = torch.where((A >= A2) & (A < A2 + self.num_classes),
                            torch.ones_like(A), torch.zeros_like(A))
            A = A.unsqueeze(0).repeat(batch_size, 1,
                                      1).to(logits.device).float()
            b = torch.ones_like(A[:, :, 0])
            probs = QPFunction()(Q, p, torch.autograd.Variable(torch.Tensor()),
                                 torch.autograd.Variable(torch.Tensor()), A, b)
            probs = probs.view(batch_size, sequence_length, self.num_classes)
            """logits_shape = logits.shape
            logits = torch.where(mask.bool().unsqueeze(-1).repeat(1, 1, logits.shape[-1]), logits, logits-10000)
            max_sequence_length = min([l for l in self.lengths if l >= sequence_length])
            if max_sequence_length > logits_shape[1]:
                logits = torch.cat((logits, torch.zeros((batch_size, max_sequence_length-logits_shape[1], logits_shape[2])).to(logits.device)), dim=1)
            lp_layer = self._layer_list[self.length_map[max_sequence_length]]
            probs, = lp_layer(logits)
            print(torch.isnan(probs).any())
            if max_sequence_length > logits_shape[1]:
                probs = probs[:,:logits_shape[1],:]"""
            logits = (torch.nn.functional.relu(probs) + 1e-4).log()
        if self._lpsmap:
            if self._lpsmap_core_only:
                all_logits = logits
            else:
                all_logits = torch.cat((logits, 0.5 * torch.ones(
                    (batch_size, 1, logits.shape[-1])).to(logits.device)),
                                       dim=1)
            probs = []
            for i in range(batch_size):
                if self.constrain_crf_decoding:
                    unaries = logits[i, :, :].view(-1).cpu()
                    additionals = self.crf.transitions.view(-1).repeat(
                        sequence_length) + 10000 * (
                            self.crf._constraint_mask[:-2, :-2] -
                            1).view(-1).repeat(sequence_length)
                    start_transitions = self.crf.start_transitions + 10000 * (
                        self.crf._constraint_mask[-2, :-2] - 1)
                    end_transitions = self.crf.start_transitions + 10000 * (
                        self.crf._constraint_mask[-1, :-2] - 1)
                    additionals = torch.cat(
                        (additionals, start_transitions, end_transitions),
                        dim=0).cpu()
                    fg = TorchFactorGraph()
                    x = fg.variable_from(unaries)
                    f = PFactorSequence()

                    f.initialize(
                        [self.num_classes for _ in range(sequence_length)])
                    factor = TorchOtherFactor(f, x, additionals)
                    fg.add(factor)
                    # add budget constraint for each state
                    for state in self._core_roles:
                        vars_state = x[state::self.num_classes]
                        fg.add(AtMostOne(vars_state))
                    # solve SparseMAP
                    fg.solve(max_iter=200)
                    probs.append(
                        unaries.to(logits.device).view(sequence_length,
                                                       self.num_classes))
                else:
                    fg = TorchFactorGraph()
                    x = fg.variable_from(all_logits[i, :, :].cpu())
                    for j in range(sequence_length):
                        fg.add(Xor(x[j, :]))
                    for j in self._core_roles:
                        fg.add(AtMostOne(x[:sequence_length, j]))
                    if not self._lpsmap_core_only:
                        full_sequence = list(range(sequence_length))
                        base_roles = set([
                            second
                            for (_, second) in self._r_roles + self._c_roles
                        ])
                        """for (r_role, base_role) in self._r_roles+self._c_roles:
                            for j in range(sequence_length):
                                fg.add(Imply(x[full_sequence+[j],[base_role]*sequence_length+[r_role]], negated=[True]*(sequence_length+1)))"""
                        for base_role in base_roles:
                            fg.add(OrOut(x[:, base_role]))
                        for (r_role,
                             base_role) in self._r_roles + self._c_roles:
                            fg.add(OrOut(x[:, r_role]))
                            fg.add(
                                Or(x[[sequence_length, sequence_length],
                                     [r_role, base_role]],
                                   negated=[True, False]))
                    max_iter = 100
                    if not self._lpsmap_core_only:
                        max_iter = min(max_iter, 400)
                    elif (not self.training) and not self._val_inference:
                        max_iter = min(max_iter, 200)
                    fg.solve(max_iter=max_iter)
                    probs.append(x.value[:sequence_length, :].contiguous().to(
                        logits.device))
            class_probabilities = torch.stack(probs)
            # class_probabilities = self.lpsmap(logits)
            max_seq_length = 200
            # if self.lpsmap is None:
            """with torch.no_grad():
                # self.lpsmap = LpSparseMap(num_rows=sequence_length, num_cols=self.num_classes, batch_size=batch_size, device=logits.device, constraints=[('xor', ('row', list(range(sequence_length)))), ('budget', ('col', self._core_roles))])
                max_iter = 1000
                constraint_types = ["xor", "budget"]
                constraint_dims = ["row", "col"]
                constraint_sets = [list(range(sequence_length)), self._core_roles]
                class_probabilities = lpsmap(logits, constraint_types, constraint_dims, constraint_sets, max_iter)
                # if max_seq_length > sequence_length:
                #     logits = torch.cat((logits, -9999.*torch.ones((batch_size, max_seq_length-sequence_length, self.num_classes)).to(logits.device)), dim=1)
                # class_probabilities = self.lpsmap.solve(logits, max_iter=max_iter)"""
            # logits = (class_probabilities+1e-4).log()
        else:
            reshaped_log_probs = logits.view(-1, self.num_classes)
            class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
                [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # print(mask.shape, tags.shape, logits.shape, tags.max(), tags.min())
            if self._lpsmap:
                loss = LpsmapLoss.apply(logits, class_probabilities, tags,
                                        mask)
                # tags_1hot = torch.zeros_like(class_probabilities).scatter_(2, tags.unsqueeze(-1), torch.ones_like(class_probabilities))
                # loss = -(tags_1hot*class_probabilities*mask.unsqueeze(-1).repeat(1, 1, class_probabilities.shape[-1])).sum()
            elif self.constrain_crf_decoding:
                loss = -self.crf(logits, tags, mask)
            else:
                loss = sequence_cross_entropy_with_logits(
                    logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                if self.constrain_crf_decoding and not self._lpsmap:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format([
                            self.vocab.get_token_from_index(
                                tag, namespace=self._label_namespace)
                            for tag in seq
                        ]) for (seq, _) in self.crf.viterbi_tags(logits, mask)
                    ]
                else:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format(tags)
                        for tags in batch_bio_predicted_tags
                    ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print(batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
            output_dict["gold_tags"] = [x["gold_tags"] for x in metadata]
        return output_dict
Example #26
    def forward(self,  # type: ignore
                # words: Dict[str, torch.LongTensor],
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                pos_logits: torch.LongTensor = None,  # predicted
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        pos_tags = None
        if pos_logits is not None and self.pos_tag_embedding is not None:
            # Embed the predicted POS tags and concatenate the embeddings to the input
            num_pos_classes = pos_logits.size(-1)
            pos_logits = pos_logits.view(-1, num_pos_classes)
            _, pos_tags = pos_logits.max(-1)

            pos_embed_size = self.pos_tag_embedding.get_output_dim()
            embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
            embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

        encoded_text = self.encoder(encoded_text, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            # "pos": [meta["pos"] for meta in metadata]
        }

        return output_dict
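
The arc-masking step above (adding a large negative constant to every padded row and column of the score matrix before decoding) can be checked in isolation. A minimal standalone sketch with made-up tensors, not tied to the parser class:

import torch

# Toy batch: 2 sentences of length 4 (ROOT sentinel + 3 tokens);
# the second sentence has one padded position at the end.
scores = torch.randn(2, 4, 4)            # (batch, seq_len, seq_len) arc scores
mask = torch.tensor([[1., 1., 1., 1.],
                     [1., 1., 1., 0.]])

minus_inf = -1e8
minus_mask = (1 - mask) * minus_inf
# Mask every padded position on both the row and the column side.
masked_scores = scores + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

# After a softmax over candidate heads, the padded column gets ~0 probability
# for every real token, so it can never be selected as a head.
head_probs = torch.softmax(masked_scores, dim=-1)
print(head_probs[1, :3, 3])              # all values are ~0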
Example #27
0
import torch
from torch import nn, LongTensor
from torch.nn.utils.rnn import pack_padded_sequence

seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6
vocab = ['<pad>']+sorted(set([char for seq in seqs for char in seq]))
#print(vocab)
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
#print(vectorized_seqs)

embed = nn.Embedding(len(vocab), 4)
lstm = nn.LSTM(4, 5, batch_first=True)

seq_lengths = LongTensor(list(map(len, vectorized_seqs)))

#print(seq_lengths)

seq_tensor = torch.zeros(len(vectorized_seqs), seq_lengths.max().item(), dtype=torch.long)
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)

#print(seq_tensor)

seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
#print(seq_tensor)

embedded_seq_tensor = embed(seq_tensor)
#print(embedded_seq_tensor)
# conceptually: (3, 8, |vocab|) one-hot  @  (|vocab|, 4) embedding matrix  -->  (3, 8, 4)
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
#print(packed_input.data.shape)
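
For completeness (not part of the original snippet), the packed batch can be fed to the LSTM and turned back into a padded tensor with pad_packed_sequence; a short continuation, assuming lstm, packed_input, and perm_idx from above are in scope:

from torch.nn.utils.rnn import pad_packed_sequence

packed_output, (h_n, c_n) = lstm(packed_input)

# Back to a regular padded tensor: (batch, max_seq_len, hidden_size) plus the true lengths.
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
#print(output.shape)     # torch.Size([3, 8, 5])
#print(output_lengths)   # tensor([8, 6, 4])

# Undo the length-sort so the rows line up with the original order of `seqs` again.
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]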
Example #28
0
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None,
                history_encoding: torch.LongTensor = None,
                turn_encoding: torch.LongTensor = None,
                scenario_encoding: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However, you will likely
            want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
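        # Shape illustration (hypothetical numbers): with max_pieces = 512 and
        # input_ids arriving as (batch_size, 1024), i.e. two concatenated windows,
        # split_indices reshapes this into (2 * batch_size, 512); the windows are
        # stitched back together after the BERT call below.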
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            input_ids = self.split_indices(input_ids)
            if token_type_ids is not None:
                token_type_ids = self.split_indices(token_type_ids)
            if history_encoding is not None:
                history_encoding = self.split_indices(history_encoding)
            if turn_encoding is not None:
                turn_encoding = self.split_indices(turn_encoding)
            if scenario_encoding is not None:
                scenario_encoding = self.split_indices(scenario_encoding)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if history_encoding is None:
            history_encoding = torch.zeros_like(input_ids)
        if turn_encoding is None:
            turn_encoding = torch.zeros_like(input_ids)
        if scenario_encoding is None:
            scenario_encoding = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, pooled_output = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            history_encoding=util.combine_initial_dims(history_encoding),
            turn_encoding=util.combine_initial_dims(turn_encoding),
            scenario_encoding=util.combine_initial_dims(scenario_encoding),
            attention_mask=util.combine_initial_dims(input_mask))
        all_encoder_layers = torch.stack(all_encoder_layers)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            unpacked_embeddings = torch.split(all_encoder_layers,
                                              batch_size,
                                              dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)
            assert batch_size == 1 and token_type_ids.max() > 0
            num_question_tokens = token_type_ids[0].nonzero().size(0)
            select_indices = self.indices_to_select(full_seq_len,
                                                    num_question_tokens)
            initial_dims.append(len(select_indices))
            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined = torch.cat(combined, dim=2)
        input_mask = (recombined_embeddings != 0).long()

        if self._scalar_mix is not None:
            mix = self._scalar_mix(recombined_embeddings, input_mask)
        else:
            mix = recombined_embeddings[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            return util.uncombine_initial_dims(mix, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            zeros = torch.zeros(offsets2d.size(0),
                                1,
                                dtype=offsets2d.dtype,
                                device=offsets2d.device)
            offsets2d = torch.cat([zeros, offsets2d], dim=-1)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length + 1)
            range_vector = util.get_range_vector(
                offsets2d.size(0), device=util.get_device_of(mix)).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length + 1)
            selected_embeddings = mix[range_vector, offsets2d]

            return util.uncombine_initial_dims(selected_embeddings,
                                               offsets.size())
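
The offsets-based gather at the end of the forward pass is plain advanced indexing. A small standalone sketch with invented shapes (torch.arange stands in for util.get_range_vector) showing how one embedding per original token is selected from the wordpiece embeddings:

import torch

batch_size, num_wordpieces, dim = 2, 6, 4
wordpiece_embeddings = torch.randn(batch_size, num_wordpieces, dim)

# "Last wordpiece" offset of each original token, as in the docstring example:
# wordpieces ["Def", "##in", "##ite", "##ly", "not"] -> offsets [3, 4].
offsets = torch.tensor([[3, 4],
                        [1, 5]])

# One row index per batch element, broadcast against the offset columns.
range_vector = torch.arange(batch_size).unsqueeze(1)    # shape (batch_size, 1)
token_embeddings = wordpiece_embeddings[range_vector, offsets]
print(token_embeddings.shape)                           # torch.Size([2, 2, 4])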
Example #29
0
import torch
from torch import nn, LongTensor

# from the earlier vectorization step:
vectorized_data = [[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]
vocab_size = 1 + max(tok for seq in vectorized_data for tok in seq)  # 15, index 0 is <pad>

# step 3 : define model

# nn.Embedding(num_embeddings, embedding_dim):
# num_embeddings is the vocabulary size, embedding_dim is the size of each embedding vector
embedding_layer = nn.Embedding(vocab_size, 4)

# input_size matches the embedding dimension (4)
# hidden_size is the size of the LSTM hidden state
lstm = nn.LSTM(input_size=4, hidden_size=5, batch_first=True)

# step 4 : prepare data by padding with 0 (the <pad> token) so every sequence in the batch has the same length
seq_lengths = LongTensor([len(seq) for seq in vectorized_data])
sequence_tensor = torch.zeros(len(vectorized_data), seq_lengths.max().item(), dtype=torch.long)

for idx, (seq, seq_len) in enumerate(zip(vectorized_data, seq_lengths)):
    sequence_tensor[idx, :seq_len] = LongTensor(seq)

# sequence_tensor = ([[ 6,  9,  8,  4,  1, 11, 12, 10],
#                     [12,  5,  8, 14,  0,  0,  0,  0],
#                     [ 7,  3,  2,  5, 13,  7,  0,  0]])

# step 5 : sort the sequences in the batch in descending order of length (pack_padded_sequence expects this unless enforce_sorted=False)
# seq_lengths = [8, 4, 6]
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
# seq_lengths = [8, 6, 4]
# perm_idx = [0, 2, 1]

sequence_tensor = sequence_tensor[perm_idx]
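
A possible step 6, continuing the walkthrough under the same assumptions (sequence_tensor, seq_lengths, embedding_layer, and lstm from above in scope): embed the sorted batch, pack it, run the LSTM, and unpad the result.

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# step 6 : embed, pack, run the LSTM, and unpad again
embedded = embedding_layer(sequence_tensor)                  # (3, 8, 4)
packed = pack_padded_sequence(embedded, seq_lengths.cpu().numpy(), batch_first=True)
packed_output, (h_n, c_n) = lstm(packed)
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
# output: (3, 8, 5), output_lengths: tensor([8, 6, 4])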
Example #30
0
    def __init__(
        self,
        assignment: torch.LongTensor,
        token_representation: HintOrType[Representation] = None,
        token_representation_kwargs: OptionalKwargs = None,
        **kwargs,
    ) -> None:
        """
        Initialize the tokenization.

        :param assignment: shape: `(n, num_chosen_tokens)`
            the token assignment.
        :param token_representation: shape: `(num_total_tokens, *shape)`
            the token representations
        :param token_representation_kwargs:
            additional keyword-based parameters
        :param kwargs:
            additional keyword-based parameters passed to super.__init__
        """
        # needs to be lazily imported to avoid cyclic imports
        from . import representation_resolver

        # fill padding (nn.Embedding cannot deal with negative indices)
        padding = assignment < 0
        # Sometimes assignment.max() does not cover all relations (e.g., inductive inference graphs
        # contain only a subset of the training relations); in that case the padding index is the
        # last index of the Representation.
        self.vocabulary_size = (
            token_representation.max_id
            if isinstance(token_representation, Representation)
            else assignment.max().item() + 2  # exclusive (+1) plus one padding token (+1)
        )

        assignment[padding] = self.vocabulary_size - 1  # = assignment.max().item() + 1
        max_id, num_chosen_tokens = assignment.shape

        # resolve token representation
        token_representation = representation_resolver.make(
            token_representation,
            token_representation_kwargs,
            max_id=self.vocabulary_size,
        )
        super().__init__(
            max_id=max_id,
            shape=(num_chosen_tokens,) + token_representation.shape,
            **kwargs,
        )

        # input validation
        if token_representation.max_id < self.vocabulary_size:
            raise ValueError(
                f"The token representations only contain {token_representation.max_id} representations, "
                f"but there are {self.vocabulary_size} tokens in use."
            )
        elif token_representation.max_id > self.vocabulary_size:
            logger.warning(
                f"The token representation contains more representations ({token_representation.max_id}) "
                f"than there are tokens in use ({self.vocabulary_size})."
            )
        # register as buffer
        self.register_buffer(name="assignment", tensor=assignment)
        # assign sub-module
        self.vocabulary = token_representation
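
The padding handling in __init__ (negative assignment entries redirected to one extra, dedicated embedding row) can be illustrated without the Representation machinery. A minimal plain-PyTorch sketch; the sizes and variable names are invented for the example:

import torch
from torch import nn

# 4 entities, each assigned up to 3 tokens; -1 marks "no token" (padding).
assignment = torch.tensor([[0, 2, -1],
                           [1, -1, -1],
                           [3, 0, 2],
                           [2, 1, -1]])

num_tokens = assignment.max().item() + 1      # real tokens 0..3
padding_idx = num_tokens                      # one extra row reserved for padding
assignment = assignment.clone()
assignment[assignment < 0] = padding_idx      # nn.Embedding cannot take negative indices

token_embedding = nn.Embedding(num_tokens + 1, 8, padding_idx=padding_idx)
entity_token_embeddings = token_embedding(assignment)
print(entity_token_embeddings.shape)          # torch.Size([4, 3, 8])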