Example no. 1
import torch
from torch.autograd import Variable
# Assumed import path for ConfigurationError (AllenNLP): allennlp.common.checks
from allennlp.common.checks import ConfigurationError


def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):
    """
    Sort a batch first tensor by some specified lengths.

    Parameters
    ----------
    tensor : Variable(torch.FloatTensor), required.
        A batch-first PyTorch tensor.
    sequence_lengths : Variable(torch.LongTensor), required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : Variable(torch.FloatTensor)
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : Variable(torch.LongTensor)
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : Variable(torch.LongTensor)
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    permutation_index : Variable(torch.LongTensor)
        The indices used to sort the tensor. This is useful if you want to sort many
        tensors using the same ordering.
    """

    if not isinstance(tensor, Variable) or not isinstance(
            sequence_lengths, Variable):
        raise ConfigurationError(
            "Both the tensor and sequence lengths must be torch.autograd.Variables."
        )

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(
        0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    # This is ugly, but required - we are creating a new variable at runtime, so we
    # must ensure it has the correct CUDA vs non-CUDA type. We do this by cloning and
    # refilling one of the inputs to the function.
    index_range = sequence_lengths.data.clone().copy_(
        torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
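A minimal usage sketch, not taken from any of the repositories above and assuming the same legacy torch.autograd.Variable API: the sorted outputs feed nn.utils.rnn.pack_padded_sequence, and restoration_indices puts the RNN output back into the original batch order. The shapes and the lstm module here are illustrative.

import torch
import torch.nn as nn
from torch.autograd import Variable

# Illustrative batch: 3 sequences, max length 5, embedding size 4.
embeddings = Variable(torch.randn(3, 5, 4))
lengths = Variable(torch.LongTensor([3, 5, 2]))

sorted_tensor, sorted_lengths, restoration_indices, _ = sort_batch_by_length(
    embeddings, lengths)

# Pack the sorted batch, run an RNN over it, and unpack.
packed = nn.utils.rnn.pack_padded_sequence(
    sorted_tensor, sorted_lengths.data.tolist(), batch_first=True)
lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
output, _ = lstm(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

# Undo the sort so the output lines up with the original batch order.
output = output.index_select(0, restoration_indices)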
Example no. 2
def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):

    if not isinstance(tensor, Variable) or not isinstance(
            sequence_lengths, Variable):
        raise Exception(
            "Both the tensor and sequence lengths must be torch.autograd.Variables."
        )

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(
        0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    index_range = sequence_lengths.data.clone().copy_(
        torch.arange(0, len(sequence_lengths)))
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
Example no. 3
def sort_batch_by_length(tensor: torch.autograd.Variable, sequence_lengths: torch.autograd.Variable):
    """
    Sort a batch first tensor by some specified lengths.

    Parameters
    ----------
    tensor : Variable(torch.FloatTensor), required.
        A batch-first PyTorch tensor.
    sequence_lengths : Variable(torch.LongTensor), required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : Variable(torch.FloatTensor)
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : Variable(torch.LongTensor)
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : Variable(torch.LongTensor)
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    """

    if not isinstance(tensor, Variable) or not isinstance(sequence_lengths, Variable):
        raise ConfigurationError("Both the tensor and sequence lengths must be torch.autograd.Variables.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    # This is ugly, but required - we are creating a new variable at runtime, so we
    # must ensure it has the correct CUDA vs non-CUDA type. We do this by cloning and
    # refilling one of the inputs to the function.
    index_range = sequence_lengths.data.clone().copy_(torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices
Example no. 4
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

# Note: biaffine() and the model class that owns forward() below are defined
# elsewhere in the originating repository and are not shown on this page.


def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):
    if not isinstance(tensor, Variable) or not isinstance(
            sequence_lengths, Variable):
        raise ValueError("Both the tensor and sequence lengths must "
                         "be torch.autograd.Variables.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(
        0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    # This is ugly, but required - we are creating a new variable at runtime, so we
    # must ensure it has the correct CUDA vs non-CUDA type. We do this by cloning and
    # refilling one of the inputs to the function.
    index_range = sequence_lengths.data.clone().copy_(
        torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index

    def forward(self, embedded_sents: torch.autograd.Variable, pos: list,
                gold_heads: list, data_lengths: np.ndarray):
        """
        Compute the logits for the arcs and the labels.

        :return: arc scores, the label logits selected at the predicted (or gold)
            heads, the maximum sequence length, and the original data lengths.
        """

        # Indices that sort the batch by decreasing length (POS lengths are the same)
        sorted_indices = np.argsort(-np.array(data_lengths)).astype(np.int64)

        # Extract sorted lens by index
        lengths = data_lengths[sorted_indices]
        max_len = lengths[0]

        embedded_sents = self.embedding_dropout(embedded_sents[:, :max_len])

        pos_t = torch.LongTensor(pos)[:, :max_len]
        embedded_pos = self.pos_embed(torch.autograd.Variable(pos_t))
        embedded_pos = self.embedding_dropout(embedded_pos)

        # Extract tensors ordered by len
        stacked_x = embedded_sents.index_select(
            dim=0,
            index=torch.autograd.Variable(
                torch.from_numpy(sorted_indices))).to(self.device)
        stacked_pos_x = embedded_pos.index_select(
            dim=0,
            index=torch.autograd.Variable(
                torch.from_numpy(sorted_indices))).to(self.device)

        # Embedding dropout was applied above; wherever it zeroed one of the two
        # representations, scale the other by 2 to compensate.
        mask_words = stacked_x - stacked_pos_x == stacked_x  # POS embedding dropped
        mask_pos = stacked_pos_x - stacked_x == stacked_pos_x  # word embedding dropped

        stacked_x[mask_words] *= 2
        stacked_pos_x[mask_pos] *= 2

        stacked_x = torch.cat((stacked_x, stacked_pos_x), dim=2)

        stacked_x = nn.utils.rnn.pack_padded_sequence(
            stacked_x,
            torch.from_numpy(lengths).to(self.device),
            batch_first=True)

        x_lstm, _ = self.lstm(stacked_x,
                              self.init_hidden_trainable(len(embedded_sents)))

        x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)

        # Restore the original batch order
        x_lstm = x_lstm.index_select(dim=0,
                                     index=torch.autograd.Variable(
                                         torch.from_numpy(
                                             np.argsort(sorted_indices).astype(
                                                 np.int64)).to(self.device)))

        # NN scoring
        h_arc_dep = self.arc_dep_mlp(x_lstm)
        h_arc_dep = self.activation(h_arc_dep)
        h_arc_dep = self.dropout(h_arc_dep)

        h_arc_head = self.arc_head_mlp(x_lstm)
        h_arc_head = self.activation(h_arc_head)
        h_arc_head = self.dropout(h_arc_head)

        h_label_dep = self.label_dep_mlp(x_lstm)
        h_label_dep = self.activation(h_label_dep)
        h_label_dep = self.dropout(h_label_dep)

        h_label_head = self.label_head_mlp(x_lstm)
        h_label_head = self.activation(h_label_head)
        h_label_head = self.dropout(h_label_head)

        # Heads computation
        s_i_arc = biaffine(h_arc_dep,
                           self.W_arc,
                           h_arc_head,
                           self.device,
                           num_outputs=1,
                           bias_x=True)

        # Labels computation
        full_label_logits = biaffine(h_label_dep,
                                     self.W_label,
                                     h_label_head,
                                     self.device,
                                     num_outputs=self.N_CLASSES,
                                     bias_x=True,
                                     bias_y=True)

        if self.training:
            # During training, use the gold heads to select the label logits.
            gold_heads_t = torch.LongTensor(gold_heads)[:, :max_len].to(
                self.device)
            # Map padding positions to head index 0 before gathering.
            m = (gold_heads_t == self.heads_vocab['<PAD>'])
            gold_heads_t[m] *= 0
            pred_arcs = gold_heads_t
        else:
            # At inference time, take the highest-scoring head for each token.
            pred_arcs = s_i_arc.argmax(-1)

        # Gather label logits from predicted or gold heads
        pred_arcs = pred_arcs.unsqueeze(2).unsqueeze(
            3)  # [batch, sent_len, 1, 1]
        pred_arcs = pred_arcs.expand(
            -1, -1, -1,
            full_label_logits.size(-1))  # [batch, sent_len, 1, n_labels]
        selected_label_logits = torch.gather(
            full_label_logits, 2,
            pred_arcs).squeeze(2)  # [batch, sent_len, n_labels]

        return s_i_arc, selected_label_logits, max_len, data_lengths
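The biaffine helper called in this forward pass is not shown on this page. A minimal sketch of a compatible scorer, with the signature and output shapes inferred only from the two calls above (an assumption, not the repository's actual implementation):

import torch


def biaffine(x, W, y, device, num_outputs=1, bias_x=False, bias_y=False):
    # x, y: [batch, sent_len, hidden]; W: [num_outputs, dim_x, dim_y].
    # Optionally append a bias column of ones to x and/or y.
    if bias_x:
        x = torch.cat([x, torch.ones(x.size(0), x.size(1), 1, device=device)], dim=-1)
    if bias_y:
        y = torch.cat([y, torch.ones(y.size(0), y.size(1), 1, device=device)], dim=-1)
    # scores[b, i, j, o] = x[b, i] . W[o] . y[b, j]
    scores = torch.einsum('bim,omn,bjn->bijo', x, W, y)
    # num_outputs == 1 (arc scores): [batch, sent_len, sent_len];
    # otherwise (label scores): [batch, sent_len, sent_len, num_outputs].
    return scores.squeeze(-1) if num_outputs == 1 else scores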