Example #1
    def hsic_score(u: torch.autograd.Variable,
                   v: torch.autograd.Variable) -> float:
        """Calculate independence score.

                Args:
                    u (torch.autograd.Variable): Residuals.
                    v (torch.autograd.Variable): Test Data.
                Returns:
                    float: Independence score.
        """

        x = u.numpy()
        y = v.numpy()

        n = x.shape[0]

        h = np.identity(n) - np.ones((n, n), dtype=float) / n

        k = HilbertSchmidtInformationCriterion.rbf_dot(
            x, HilbertSchmidtInformationCriterion.get_width(x, n))
        l = HilbertSchmidtInformationCriterion.rbf_dot(
            y, HilbertSchmidtInformationCriterion.get_width(y, n))

        k_c = np.dot(np.dot(h, k), h)
        l_c = np.dot(np.dot(h, l), h)

        test_stat = np.sum(k_c.T * l_c) / n

        return test_stat
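
For reference, the quantity returned on the last line is the standard biased HSIC estimate: with centred kernel matrices K_c = HKH and L_c = HLH, np.sum(k_c.T * l_c) equals trace(K_c · L_c), so hsic_score returns trace(HKH · HLH) / n. Below is a minimal NumPy sketch of that identity; the rbf_kernel helper is only a stand-in for the rbf_dot / get_width pair used above, with an assumed bandwidth convention.

import numpy as np

def rbf_kernel(x, width):
    # Gaussian kernel matrix; a stand-in for HilbertSchmidtInformationCriterion.rbf_dot.
    sq = np.sum(x ** 2, axis=1, keepdims=True)
    dist = sq + sq.T - 2.0 * x @ x.T
    return np.exp(-dist / (2.0 * width ** 2))

rng = np.random.default_rng(0)
x, y = rng.normal(size=(50, 1)), rng.normal(size=(50, 1))
n = x.shape[0]
h = np.identity(n) - np.ones((n, n)) / n
k_c = h @ rbf_kernel(x, 1.0) @ h
l_c = h @ rbf_kernel(y, 1.0) @ h
assert np.isclose(np.sum(k_c.T * l_c), np.trace(k_c @ l_c))
print(np.sum(k_c.T * l_c) / n)    # the biased HSIC statistic, as hsic_score returns it
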
Example #2
def get_dropout_mask(dropout_probability: float, tensor_for_masking: torch.autograd.Variable):
    """
    Computes and returns an element-wise dropout mask for a given tensor, where
    each element in the mask is dropped out with probability dropout_probability.
    Note that the mask is NOT applied to the tensor - the tensor is passed to retain
    the correct CUDA tensor type for the mask.

    Parameters
    ----------
    dropout_probability : float, required.
        Probability of dropping a dimension of the input.
    tensor_for_masking : torch.autograd.Variable, required.
        Tensor whose size and CUDA type the returned mask should match.

    Returns
    -------
    A torch.FloatTensor consisting of the binary mask scaled by 1 / (1 - dropout_probability).
    This scaling ensures that the expected value and variance of the output of applying this
    mask match those of the original tensor.
    """
    binary_mask = tensor_for_masking.clone()
    binary_mask.data.copy_(torch.rand(tensor_for_masking.size()) > dropout_probability)
    # Scale mask by 1/keep_prob to preserve output statistics.
    dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
    return dropout_mask
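
The mask is returned rather than applied, so the caller multiplies it into the tensor. A minimal usage sketch of the inverted-dropout pattern described in the docstring; the shapes and dropout probability are illustrative.

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4, 10))
mask = get_dropout_mask(0.5, x)   # zeros with probability 0.5, survivors scaled by 1 / 0.5 = 2.0
dropped = x * mask                # the expected value of each element is unchanged
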
Example #3
File: util.py Project: xumx/allennlp
def get_dropout_mask(dropout_probability: float, tensor_for_masking: torch.autograd.Variable):
    """
    Computes and returns an element-wise dropout mask for a given tensor, where
    each element in the mask is dropped out with probability dropout_probability.
    Note that the mask is NOT applied to the tensor - the tensor is passed to retain
    the correct CUDA tensor type for the mask.

    Parameters
    ----------
    dropout_probability : float, required.
        Probability of dropping a dimension of the input.
    tensor_for_masking : torch.autograd.Variable, required.
        Tensor whose size and CUDA type the returned mask should match.

    Returns
    -------
    A torch.FloatTensor consisting of the binary mask scaled by 1 / (1 - dropout_probability).
    This scaling ensures that the expected value and variance of the output of applying this
    mask match those of the original tensor.
    """
    binary_mask = tensor_for_masking.clone()
    binary_mask.data.copy_(torch.rand(tensor_for_masking.size()) > dropout_probability)
    # Scale mask by 1/keep_prob to preserve output statistics.
    dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
    return dropout_mask
Example #4
    def forward(self, x: torch.autograd.Variable):
        # x's shape must be [b, self.in_size]
        assert x.size(1) == self.in_size

        # (b, in_size)

        x = x.view(x.size(0), self.in_size, 1, 1)
        x = self.deconv(x)

        # (b, out_channels, 128, 128)

        return x
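
self.deconv is not shown in this snippet; one common way to map a (b, in_size, 1, 1) code to a (b, out_channels, 128, 128) image is a DCGAN-style stack of transposed convolutions. The sketch below is purely illustrative: the layer count, channel widths, and activations are assumptions, not this project's architecture.

import torch
import torch.nn as nn

in_size, out_channels = 100, 3
deconv = nn.Sequential(
    nn.ConvTranspose2d(in_size, 512, 4, 1, 0), nn.ReLU(True),     # 1x1   -> 4x4
    nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.ReLU(True),         # 4x4   -> 8x8
    nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(True),         # 8x8   -> 16x16
    nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(True),          # 16x16 -> 32x32
    nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(True),           # 32x32 -> 64x64
    nn.ConvTranspose2d(32, out_channels, 4, 2, 1), nn.Tanh(),     # 64x64 -> 128x128
)
x = torch.randn(2, in_size)
x = x.view(x.size(0), in_size, 1, 1)
assert deconv(x).shape == (2, out_channels, 128, 128)
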
Example #5
 def forward(self, net_output: torch.autograd.Variable,
             y: torch.Tensor) -> torch.autograd.Variable:
     """
     Args:
         net_output: Predicted values, size should be [n, 2], n is the number of records.
         y: True values
     Returns: loss function variable
     """
     param_mu = net_output.t()[0]
     param_sigma = net_output.t()[1]
     predictive_dist = Normal(loc=param_mu, scale=param_sigma)
     log_likelihoods = predictive_dist.log_prob(y)
     return -log_likelihoods.sum()
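
The loss is the summed negative log-likelihood of y under a Gaussian whose mean and scale come from the two output columns. A small sketch confirming that Normal.log_prob matches the closed-form expression 0.5·log(2πσ²) + (y − μ)² / (2σ²) per record; the values are illustrative.

import math
import torch
from torch.distributions import Normal

mu = torch.tensor([0.0, 1.0])          # column 0 of a [n, 2] net_output
sigma = torch.tensor([1.0, 0.5])       # column 1 of the same net_output
y = torch.tensor([0.1, 0.9])

nll = -Normal(loc=mu, scale=sigma).log_prob(y).sum()
closed_form = (0.5 * torch.log(2 * math.pi * sigma ** 2)
               + (y - mu) ** 2 / (2 * sigma ** 2)).sum()
assert torch.allclose(nll, closed_form)
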
Example #6
File: rnn.py Project: Mrpatekful/translate
    def forward(self,
                encoder_outputs: torch.autograd.Variable,
                hidden_state: torch.autograd.Variable,
                targets: torch.autograd.Variable = None,
                max_length: int = None) -> tuple:
        """
        Forward step of the attentional decoder unit. If the targets parameter is not None, teacher forcing
        is used, so the decoder receives the previous target word as input at time step t. If targets is None,
        decoding follows the usual procedure, where the input to the recurrent unit at time step t is the word
        predicted at time step t-1.

        :param targets:
            Variable, (batch_size, sequence_length) a batch of word ids.

        :param max_length:
            int, maximum length of the decoded sequence. If None, the maximum length parameter from the
            configuration is used. This parameter has no effect if targets is provided.

        :param encoder_outputs:
            Variable, with size of (batch_size, sequence_length, hidden_size).

        :param hidden_state:
            Variable, (num_layers * directions, batch_size, hidden_size) initial hidden state.

        :return outputs:
            dict, containing the decoded word ids under 'symbols' and the attention weights under
            'alignment_weights' (both NumPy arrays); returned together with the list of step-wise predictions.
        """
        batch_size = encoder_outputs.size(0)
        input_sequence_length = encoder_outputs.size(1)

        if targets is not None:

            predictions = self._forced_decode(
                targets=targets,
                batch_size=batch_size,
                hidden_state=hidden_state,
                encoder_outputs=encoder_outputs,
                input_sequence_length=input_sequence_length)

        else:

            predictions = self._predictive_decode(
                max_length=max_length,
                batch_size=batch_size,
                hidden_state=hidden_state,
                encoder_outputs=encoder_outputs,
                input_sequence_length=input_sequence_length)

        return self._outputs, predictions
Example #7
def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):
    """
    Sort a batch first tensor by some specified lengths.

    Parameters
    ----------
    tensor : Variable(torch.FloatTensor), required.
        A batch first Pytorch tensor.
    sequence_lengths : Variable(torch.LongTensor), required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : Variable(torch.FloatTensor)
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : Variable(torch.LongTensor)
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : Variable(torch.LongTensor)
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    permutation_index : Variable(torch.LongTensor)
        The indices used to sort the tensor. This is useful if you want to sort many
        tensors using the same ordering.
    """

    if not isinstance(tensor, Variable) or not isinstance(
            sequence_lengths, Variable):
        raise ConfigurationError(
            "Both the tensor and sequence lengths must be torch.autograd.Variables."
        )

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(
        0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    # This is ugly, but required - we are creating a new variable at runtime, so we
    # must ensure it has the correct CUDA vs non-CUDA type. We do this by cloning and
    # refilling one of the inputs to the function.
    index_range = sequence_lengths.data.clone().copy_(
        torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
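
A typical use, sketched under the assumption of a recent PyTorch where Variable is a thin wrapper over Tensor: sort before packing a padded batch, then undo the sort with restoration_indices. The shapes and lengths are illustrative.

import torch
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence

tensor = Variable(torch.randn(3, 5, 8))             # (batch, time, features), padded
lengths = Variable(torch.LongTensor([2, 5, 3]))

sorted_tensor, sorted_lengths, restoration_indices, _ = sort_batch_by_length(tensor, lengths)
packed = pack_padded_sequence(sorted_tensor, sorted_lengths.data.tolist(), batch_first=True)
# ... run an RNN over `packed` and pad it back ...
restored = sorted_tensor.index_select(0, restoration_indices)
assert torch.equal(restored.data, tensor.data)
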
Example #8
 def forward(self, net_output: torch.autograd.Variable,
             y: torch.Tensor) -> torch.autograd.Variable:
     """
     Args:
         net_output: Predicted values, size should be [n, num_class], n is the number of records.
         y: True values.
     Returns: loss function variable
     """
     if self.pre_log:
         log_likelihoods = net_output.gather(dim=1,
                                             index=y.type(
                                                 torch.long).unsqueeze(1))
     else:
         log_likelihoods = net_output.log().gather(
             dim=1, index=y.type(torch.long).unsqueeze(1))
     return -log_likelihoods.mean()
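
For the pre_log case, the gather call picks out, for each record, the log-probability assigned to its true class, so the result is the mean negative log-likelihood. A small standalone check against F.nll_loss, which is used here only as a reference implementation; the data is illustrative.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
log_probs = F.log_softmax(logits, dim=1)        # the "pre_log" case: inputs are log-probabilities
y = torch.tensor([0, 2, 1, 2])

picked = log_probs.gather(dim=1, index=y.type(torch.long).unsqueeze(1))
assert torch.allclose(-picked.mean(), F.nll_loss(log_probs, y))
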
Example #9
File: rnn.py Project: Mrpatekful/translate
    def _score(
            self, encoder_output: torch.autograd.Variable,
            decoder_state: torch.autograd.Variable) -> torch.autograd.Variable:
        """
        The score computation is as follows:

            h_d * (W_a * h_eT)

        where h_d is the decoder hidden state, W_a is a linear layer and h_eT is
        the transpose of encoder output at time step t.

        :param encoder_output:
            Variable, (batch_size, 1, hidden_layer) output of the encoder at time step t.

        :param decoder_state:
            Variable, (batch_size, 1, hidden_layer) hidden state of the decoder at time step t.

        :return energy:
            Variable, similarity between the decoder and encoder state.
        """
        energy = self._attention_layer(encoder_output)
        energy = torch.bmm(decoder_state.unsqueeze(1),
                           energy.unsqueeze(1).transpose(1, 2)).squeeze(-1)

        return energy
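
In Example #16 this method is called with 2-D slices of shape (batch_size, hidden_size); tracing those shapes through the linear layer and the batched matrix product gives a (batch_size, 1) energy per time step. A standalone sketch of the same multiplicative score, with illustrative sizes.

import torch
import torch.nn as nn

batch_size, hidden_size = 4, 16
attention_layer = nn.Linear(hidden_size, hidden_size)   # stands in for self._attention_layer
encoder_output = torch.randn(batch_size, hidden_size)   # encoder output at one time step
decoder_state = torch.randn(batch_size, hidden_size)    # squeezed decoder hidden state

energy = attention_layer(encoder_output)                                  # (batch, hidden)
energy = torch.bmm(decoder_state.unsqueeze(1),                            # (batch, 1, hidden)
                   energy.unsqueeze(1).transpose(1, 2)).squeeze(-1)       # @ (batch, hidden, 1)
assert energy.shape == (batch_size, 1)
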
Example #10
File: rnn.py Project: Mrpatekful/translate
    def _forced_decode(self, targets: torch.autograd.Variable, batch_size: int,
                       hidden_state: torch.autograd.Variable,
                       encoder_outputs: torch.autograd.Variable,
                       input_sequence_length: int) -> list:
        """
        This method is primarily used during training, when target outputs are provided to the decoder.
        These target sequences start with an <SOS> token, which will serve as the first input to the _decode
        function. During the decoding iterations the decoder's predictions will only be used as final outputs to
        measure the loss, so the input for the (t)-th time step will be the (t-1)-th element of the provided
        targets.

        :param targets:
            Variable, (batch_size, sequence_length) a batch of word ids.

        :param batch_size:
            int, size of the currently processed batch.

        :param hidden_state:
            Variable, (num_layers * directions, batch_size, hidden_size) initial hidden state.

        :param encoder_outputs:
            Variable, with size of (batch_size, sequence_length, hidden_size).

        :param input_sequence_length:
            int, length of the input (for the encoder) sequence.
        """
        output_sequence_length = targets.size(1) - 1

        inputs = targets[:, :-1].contiguous()
        embedded_inputs = self.embedding(inputs)

        predictions = []

        self._outputs['symbols'] = numpy.zeros(
            (batch_size, output_sequence_length), dtype='int')
        self._outputs['alignment_weights'] = numpy.zeros(
            (batch_size, output_sequence_length, input_sequence_length))

        for step in range(output_sequence_length):
            step_input = embedded_inputs[:, step, :]
            step_input = step_input.unsqueeze(1)
            step_output, hidden_state, attn_weights = self._decode(
                inputs=step_input,
                hidden_state=hidden_state,
                encoder_outputs=encoder_outputs,
                batch_size=batch_size,
                sequence_length=input_sequence_length)

            predictions.append(step_output.squeeze(1))
            self._outputs[
                'alignment_weights'][:, step, :] = attn_weights.data.squeeze(
                    1).cpu().numpy()
            self._outputs['symbols'][:, step] = step_output.topk(
                1)[1].data.squeeze(-1).squeeze(-1).cpu().numpy()

        return predictions
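
Teacher forcing here amounts to shifting the target sequence by one position: the decoder consumes targets[:, :-1] (starting with <SOS>), and its step-wise predictions are presumably compared against targets[:, 1:] in a loss computed outside this method. A tiny sketch of that shift; the word ids are made up.

import torch

targets = torch.tensor([[1, 5, 6, 7, 2]])    # <SOS> w1 w2 w3 <EOS>
inputs = targets[:, :-1].contiguous()        # fed to the decoder one step at a time
labels = targets[:, 1:]                      # what the step-wise predictions line up with
assert inputs.size(1) == targets.size(1) - 1 == labels.size(1)
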
Example #11
def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):

    if not isinstance(tensor, Variable) or not isinstance(
            sequence_lengths, Variable):
        raise Exception(
            "Both the tensor and sequence lengths must be torch.autograd.Variables."
        )

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(
        0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    index_range = sequence_lengths.data.clone().copy_(
        torch.arange(0, len(sequence_lengths)))
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
Example #12
def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):
    if not isinstance(tensor, Variable) or not isinstance(
            sequence_lengths, Variable):
        raise ValueError("Both the tensor and sequence lengths must "
                         "be torch.autograd.Variables.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(
        0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    # This is ugly, but required - we are creating a new variable at runtime, so we
    # must ensure it has the correct CUDA vs non-CUDA type. We do this by cloning and
    # refilling one of the inputs to the function.
    index_range = sequence_lengths.data.clone().copy_(
        torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
Example #13
def sort_batch_by_length(tensor: torch.autograd.Variable, sequence_lengths: torch.autograd.Variable):
    """
    Sort a batch first tensor by some specified lengths.

    Parameters
    ----------
    tensor : Variable(torch.FloatTensor), required.
        A batch first Pytorch tensor.
    sequence_lengths : Variable(torch.LongTensor), required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : Variable(torch.FloatTensor)
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : Variable(torch.LongTensor)
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : Variable(torch.LongTensor)
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    """

    if not isinstance(tensor, Variable) or not isinstance(sequence_lengths, Variable):
        raise ConfigurationError("Both the tensor and sequence lengths must be torch.autograd.Variables.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    # This is ugly, but required - we are creating a new variable at runtime, so we
    # must ensure it has the correct CUDA vs non-CUDA type. We do this by cloning and
    # refilling one of the inputs to the function.
    index_range = sequence_lengths.data.clone().copy_(torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices
Example #14
File: rnn.py Project: Mrpatekful/translate
    def forward(self,
                encoder_outputs: torch.autograd.Variable,
                hidden_state: torch.autograd.Variable,
                targets: torch.autograd.Variable = None,
                max_length: int = None) -> tuple:
        """
        Forward step of the decoder unit. If the targets parameter is not None, teacher forcing is used,
        so the decoder receives the previous target word as input at time step t. If targets is None,
        decoding follows the usual procedure, where the input to the recurrent unit at time step t is the
        word predicted at time step t-1.

        :param targets:
            Variable, (batch_size, sequence_length) a batch of word ids. If None, teacher forcing is
            not applied.

        :param max_length:
            int, maximum length of the decoded sequence. If None, the maximum length parameter from
            the configuration file is used. This parameter has no effect if targets is provided,
            because in that case the length of the target sequence determines the decoding length.

        :param encoder_outputs:
            Variable, with size of (batch_size, sequence_length, hidden_size). This parameter
            is redundant for the standard decoder unit.

        :param hidden_state:
            Variable, (num_layers * directions, batch_size, hidden_size) initial hidden state.

        :return decoder_outputs:
            dict, containing the decoded word ids under 'symbols' (ndarray); returned together with the
            list of step-wise predictions.
        """
        batch_size = encoder_outputs.size(0)

        if targets is not None:

            predictions = self._forced_decode(targets=targets,
                                              batch_size=batch_size,
                                              hidden_state=hidden_state,
                                              encoder_outputs=encoder_outputs,
                                              input_sequence_length=None)

        else:

            predictions = self._predictive_decode(
                max_length=max_length,
                batch_size=batch_size,
                hidden_state=hidden_state,
                encoder_outputs=encoder_outputs,
                input_sequence_length=None)

        return self._outputs, predictions
Example #15
File: rnn.py Project: Mrpatekful/translate
    def _forced_decode(self,
                       targets: torch.autograd.Variable,
                       batch_size: int,
                       hidden_state: torch.autograd.Variable,
                       encoder_outputs: torch.autograd.Variable,
                       input_sequence_length: int = None) -> list:
        """
        This method is primarily used during training, when target outputs are provided to the decoder.
        These target sequences start with an <SOS> token, which will serve as the first input to the _decode
        function. During the decoding iterations the decoder's predictions will only be used as final outputs to
        measure the loss, so the input for the (t)-th time step will be the (t-1)-th element of the provided
        targets.

        :param targets:
            Variable, (batch_size, sequence_length) a batch of word ids.

        :param batch_size:
            int, size of the currently processed batch.

        :param hidden_state:
            Variable, (num_layers * directions, batch_size, hidden_size) initial hidden state.

        :param encoder_outputs:
            Variable, with size of (batch_size, sequence_length, hidden_size).

        :param input_sequence_length:
            This parameter is required only by the attentional version of this method.
        """
        output_sequence_length = targets.size(1) - 1

        self._outputs['symbols'] = numpy.zeros(
            (batch_size, output_sequence_length), dtype=numpy.int32)

        predictions = []

        inputs = targets[:, :-1].contiguous()
        embedded_inputs = self.embedding(inputs)

        outputs, hidden_state, _ = self._decode(inputs=embedded_inputs,
                                                hidden_state=hidden_state,
                                                encoder_outputs=None,
                                                batch_size=batch_size,
                                                sequence_length=None)

        for step in range(output_sequence_length):
            self._outputs['symbols'][:, step] = outputs[:, step, :].topk(
                1)[1].squeeze(-1).data.cpu().numpy()
            predictions.append(outputs[:, step, :])

        return predictions
Example #16
File: rnn.py Project: Mrpatekful/translate
    def _calculate_context(self, decoder_state: torch.autograd.Variable,
                           encoder_outputs: torch.autograd.Variable,
                           batch_size: int, sequence_length: int) -> tuple:
        """
        Calculates the context for the decoder, given the encoder outputs and a decoder hidden state.
        The algorithm iterates through the encoder outputs and scores each output based on the similarity
        with the decoder state. The scoring functions are implemented in the child nodes of this class.

        :param decoder_state:
            Variable, (batch_size, 1, hidden_size) the state of the decoder.

        :param encoder_outputs:
            Variable, (batch_size, sequence_length, hidden_size) the output of
            each time step of the encoder.

        :param batch_size:
            int, size of the input batch.

        :param sequence_length:
            int, size of the sequence.

        :return context:
                Variable, the weighted sum of the encoder outputs.

        :return attn_weights:
            Variable, weights used in the calculation of the context.
        """
        attn_energies = torch.zeros([batch_size, sequence_length])

        if self._cuda:
            attn_energies = attn_energies.cuda()

        attn_energies = autograd.Variable(attn_energies)

        squeezed_output = decoder_state.squeeze(1)
        for step in range(sequence_length):
            attn_energies[:, step] = self._score(encoder_outputs[:, step],
                                                 squeezed_output)

        attn_weights = functional.softmax(attn_energies, dim=1).unsqueeze(1)
        context = torch.bmm(attn_weights, encoder_outputs)

        return context, attn_weights
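
The last two lines are the standard attention read-out: normalise the energies over the source positions and take the weighted sum of the encoder outputs with a batched matrix product. A shape sketch with illustrative sizes.

import torch
import torch.nn.functional as functional

batch_size, sequence_length, hidden_size = 4, 7, 16
attn_energies = torch.randn(batch_size, sequence_length)
encoder_outputs = torch.randn(batch_size, sequence_length, hidden_size)

attn_weights = functional.softmax(attn_energies, dim=1).unsqueeze(1)   # (batch, 1, seq_len)
context = torch.bmm(attn_weights, encoder_outputs)                     # (batch, 1, hidden)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, 1))
assert context.shape == (batch_size, 1, hidden_size)
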
Example #17
    def probabilities(self,
                      states: torch.autograd.Variable,
                      training: bool = True) -> np.ndarray:
        epsilon = self._epsilon if training else 0

        q_values = self._model.q_values(states)
        # noinspection PyArgumentList
        _, argmax = torch.max(q_values, dim=1)
        batch_size = states.size()[0]
        probabilities: torch.FloatTensor = torch.ones((batch_size, self._model.num_actions)) * \
            epsilon / self._model.num_actions
        arange = torch.arange(0, batch_size).type(torch.LongTensor)
        if self._model.is_cuda:
            probabilities = probabilities.cuda()
            arange = arange.cuda()
        probabilities[arange, argmax.data] += (1 - epsilon)
        if self._model.is_cuda:
            return probabilities.cpu().numpy()[0]
        else:
            return probabilities.numpy()[0]
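
The returned row is an ε-greedy distribution: every action receives ε / num_actions and the greedy action additionally receives 1 − ε, so each row sums to 1. A standalone check of that construction, with illustrative Q-values.

import torch

epsilon, num_actions = 0.1, 4
q_values = torch.tensor([[0.2, 1.5, -0.3, 0.7]])    # one state, four actions
_, argmax = torch.max(q_values, dim=1)

batch_size = q_values.size(0)
probabilities = torch.ones((batch_size, num_actions)) * epsilon / num_actions
probabilities[torch.arange(batch_size), argmax] += 1 - epsilon

assert torch.allclose(probabilities.sum(dim=1), torch.ones(batch_size))
assert torch.isclose(probabilities[0, argmax[0]],
                     torch.tensor(1 - epsilon + epsilon / num_actions))
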
Example #18
File: ops.py Project: Macbull/room_layout
 def __call__(self, x: torch.autograd.Variable):
     assert x.dim() == 4, 'input tensor should be 4D'
     return F.conv2d(x,
                     self.kernel,
                     padding=self.padding,
                     dilation=self.dilation)
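
A standalone illustration of the same call with a fixed kernel; the 3x3 box filter, padding, and dilation below are assumptions, not the values stored on this object.

import torch
import torch.nn.functional as F

kernel = torch.ones(1, 1, 3, 3) / 9.0      # 3x3 box filter: out_channels=1, in_channels=1
x = torch.randn(1, 1, 8, 8)                # 4D input, as the assert above requires
out = F.conv2d(x, kernel, padding=1, dilation=1)
assert out.shape == x.shape                # padding=1 with a 3x3 kernel preserves spatial size
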
Example #19
 def register(self, variable: torch.autograd.Variable, label):
     variable.register_hook(self.create_hook(label))
     return self
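
create_hook is not shown here, but register_hook itself is the stock PyTorch API: it attaches a callable that receives the gradient of the variable during backward. A minimal illustration independent of this class.

import torch

x = torch.randn(3, requires_grad=True)
# The hook receives the gradient w.r.t. x during backward; returning None leaves it unchanged.
handle = x.register_hook(lambda grad: print('grad for x:', grad))
(x * 2).sum().backward()                   # prints a vector of 2s
handle.remove()                            # hooks can be detached again via the returned handle
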
Example #20
 def set_embedding(self, name: str, vector: torch.autograd.Variable):
     self._embeddings[name] = vector.cpu()
Example #21
    def forward(self, embedded_sents: torch.autograd.Variable, pos: list,
                gold_heads: list, data_lengths: np.ndarray):
        """
        Returns the logits for the labels and the arcs
        :return:
        """

        # Sort lens (pos lens are the same)
        sorted_indices = np.argsort(-np.array(data_lengths)).astype(np.int64)

        # Extract sorted lens by index
        lengths = data_lengths[sorted_indices]
        max_len = lengths[0]

        embedded_sents = self.embedding_dropout(embedded_sents[:, :max_len])

        pos_t = torch.LongTensor(pos)[:, :max_len]
        embedded_pos = self.pos_embed(torch.autograd.Variable(pos_t))
        embedded_pos = self.embedding_dropout(embedded_pos)

        # Extract tensors ordered by len
        stacked_x = embedded_sents.index_select(
            dim=0,
            index=torch.autograd.Variable(
                torch.from_numpy(sorted_indices))).to(self.device)
        stacked_pos_x = embedded_pos.index_select(
            dim=0,
            index=torch.autograd.Variable(
                torch.from_numpy(sorted_indices))).to(self.device)

        # Apply dropout and when one is dropped scale the other
        mask_words = stacked_x - stacked_pos_x == stacked_x

        mask_pos = stacked_pos_x - stacked_x == stacked_pos_x

        stacked_x[mask_words] *= 2
        stacked_pos_x[mask_pos] *= 2

        stacked_x = torch.cat((stacked_x, stacked_pos_x), dim=2)

        stacked_x = nn.utils.rnn.pack_padded_sequence(
            stacked_x,
            torch.from_numpy(lengths).to(self.device),
            batch_first=True)

        x_lstm, _ = self.lstm(stacked_x,
                              self.init_hidden_trainable(len(embedded_sents)))

        x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)

        # Reorder the batch
        x_lstm = x_lstm.index_select(dim=0,
                                     index=torch.autograd.Variable(
                                         torch.from_numpy(
                                             np.argsort(sorted_indices).astype(
                                                 np.int64)).to(self.device)))

        # NN scoring
        h_arc_dep = self.arc_dep_mlp(x_lstm)
        h_arc_dep = self.activation(h_arc_dep)
        h_arc_dep = self.dropout(h_arc_dep)

        h_arc_head = self.arc_head_mlp(x_lstm)
        h_arc_head = self.activation(h_arc_head)
        h_arc_head = self.dropout(h_arc_head)

        h_label_dep = self.label_dep_mlp(x_lstm)
        h_label_dep = self.activation(h_label_dep)
        h_label_dep = self.dropout(h_label_dep)

        h_label_head = self.label_head_mlp(x_lstm)
        h_label_head = self.activation(h_label_head)
        h_label_head = self.dropout(h_label_head)

        # Heads computation
        s_i_arc = biaffine(h_arc_dep,
                           self.W_arc,
                           h_arc_head,
                           self.device,
                           num_outputs=1,
                           bias_x=True)

        # Labels computation
        full_label_logits = biaffine(h_label_dep,
                                     self.W_label,
                                     h_label_head,
                                     self.device,
                                     num_outputs=self.N_CLASSES,
                                     bias_x=True,
                                     bias_y=True)

        if self.training:
            gold_heads_t = torch.LongTensor(gold_heads)[:, :max_len].to(
                self.device)
            m = (gold_heads_t == self.heads_vocab['<PAD>'])
            gold_heads_t[m] *= 0
            pred_arcs = gold_heads_t
        else:
            pred_arcs = s_i_arc.argmax(-1)

        # Gather label logits from predicted or gold heads
        pred_arcs = pred_arcs.unsqueeze(2).unsqueeze(
            3)  # [batch, sent_len, 1, 1]
        pred_arcs = pred_arcs.expand(
            -1, -1, -1,
            full_label_logits.size(-1))  # [batch, sent_len, 1, n_labels]
        selected_label_logits = torch.gather(
            full_label_logits, 2,
            pred_arcs).squeeze(2)  # [batch, sent_len, n_labels]

        return s_i_arc, selected_label_logits, max_len, data_lengths
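
The biaffine helper called above is not part of this snippet. For orientation, a common Dozat-and-Manning-style scorer looks roughly like the sketch below; the name, argument order, and output layout are assumptions and need not match the project's own implementation (which also takes a device argument).

import torch

def biaffine_sketch(x, w, y, num_outputs=1, bias_x=False, bias_y=False):
    # x: (batch, sent_len, d_x), y: (batch, sent_len, d_y)
    # w: (num_outputs, d_x + bias_x, d_y + bias_y)
    if bias_x:
        x = torch.cat([x, x.new_ones(x.size(0), x.size(1), 1)], dim=-1)
    if bias_y:
        y = torch.cat([y, y.new_ones(y.size(0), y.size(1), 1)], dim=-1)
    # scores[b, dep, head, o] = x[b, dep] @ w[o] @ y[b, head]
    scores = torch.einsum('bxi,oij,byj->bxyo', x, w, y)
    return scores.squeeze(-1) if num_outputs == 1 else scores

h_dep, h_head = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
w_arc = torch.randn(1, 9, 8)                      # extra input dimension for bias_x
assert biaffine_sketch(h_dep, w_arc, h_head, num_outputs=1, bias_x=True).shape == (2, 5, 5)
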