Example #1
def masking_blockdiagonal(passage_length, window, device_id):
    """ Make a (passage_length, passage_length) tensor M of 1 and -1 in which for each row x,
        M[x, y] = -1 if y < x - window or y > x + window, else it is 1.
        Basically for the x-th row, the [x-win, x+win] columns should be 1, and rest -1
    """

    lower_limit = [max(0, i - window) for i in range(passage_length)]
    upper_limit = [
        min(passage_length, i + window) for i in range(passage_length)
    ]

    # Tensors of lower and upper limits for each row
    lower = allenutil.move_to_device(torch.LongTensor(lower_limit),
                                     cuda_device=device_id)
    upper = allenutil.move_to_device(torch.LongTensor(upper_limit),
                                     cuda_device=device_id)
    lower_un = lower.unsqueeze(1)
    upper_un = upper.unsqueeze(1)

    # Range vector for each row
    lower_range_vector = allenutil.get_range_vector(
        passage_length, device=device_id).unsqueeze(0)
    upper_range_vector = allenutil.get_range_vector(
        passage_length, device=device_id).unsqueeze(0)

    # Masks for lower and upper limits of the mask
    lower_mask = lower_range_vector >= lower_un
    upper_mask = upper_range_vector <= upper_un

    # Final-mask that we require
    # Shape: (passage_length, passage_length); (passage_length, passage_length)
    inwindow_mask = (lower_mask == upper_mask).float()
    outwindow_mask = (lower_mask != upper_mask).float()

    return inwindow_mask, outwindow_mask
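The window mask above depends only on row and column indices, so the same result can be obtained with a single broadcasted comparison. Below is a minimal, self-contained sketch using plain torch.arange in place of the AllenNLP utilities get_range_vector / move_to_device (which are assumed to behave as their names suggest):

import torch

def window_mask(passage_length: int, window: int) -> torch.Tensor:
    cols = torch.arange(passage_length).unsqueeze(0)   # (1, passage_length)
    rows = torch.arange(passage_length).unsqueeze(1)   # (passage_length, 1)
    # 1.0 inside [row - window, row + window], 0.0 outside
    return ((cols >= rows - window) & (cols <= rows + window)).float()

print(window_mask(5, 1))
# tensor([[1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0.],
#         [0., 0., 1., 1., 1.],
#         [0., 0., 0., 1., 1.]])

The out-of-window mask returned above is just the complement, i.e. 1.0 minus this tensor.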
Example #2
    def forward(self, inputs: torch.Tensor, offsets: torch.Tensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``, required
            A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
            for the current batch.
        offsets: ``torch.Tensor``, required
            A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
            for the current batch.

        Returns
        -------
        ``[torch.Tensor]``
            An embedding representation of the input sequence
            having shape ``(batch_size, sequence_length, embedding_dim)``
        """
        # pylint: disable=arguments-differ
        batch_size, num_timesteps = inputs.size()

        # the transformer embedding consists of the byte pair embeddings,
        # the special embeddings and the position embeddings.
        # the position embeddings are always at least self._transformer.n_ctx,
        # but may be longer.
        # the transformer "vocab" consists of the actual vocab and the
        # positional encodings. Here we want the count of just the former.
        vocab_size = self._transformer.vocab_size - self._transformer.n_ctx

        # vocab_size, vocab_size + 1, ...
        positional_encodings = get_range_vector(num_timesteps, device=get_device_of(inputs)) + vocab_size

        # Combine the inputs with positional encodings
        batch_tensor = torch.stack([
                inputs,   # (batch_size, num_timesteps)
                positional_encodings.expand(batch_size, num_timesteps)
        ], dim=-1)

        byte_pairs_mask = inputs != 0

        # Embeddings is num_output_layers x (batch_size, num_timesteps, embedding_dim)
        layer_activations = self._transformer(batch_tensor)

        # Output of scalar_mix is (batch_size, num_timesteps, embedding_dim)
        if self._top_layer_only:
            mix = layer_activations[-1]
        else:
            mix = self._scalar_mix(layer_activations, byte_pairs_mask)

        # These embeddings are one per byte-pair, but we want one per original _word_.
        # So we choose the embedding corresponding to the last byte pair for each word,
        # which is captured by the ``offsets`` input.
        if offsets is not None:
            range_vector = get_range_vector(batch_size, device=get_device_of(mix)).unsqueeze(1)
            last_byte_pair_embeddings = mix[range_vector, offsets]
        else:
            # allow to return all byte pairs by passing no offsets
            seq_len = (byte_pairs_mask > 0).long().sum(dim=1).max()
            last_byte_pair_embeddings = mix[:, :seq_len]

        return last_byte_pair_embeddings
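A minimal sketch of the final indexing step, mix[range_vector, offsets]: the range vector has shape (batch_size, 1) and broadcasts against offsets, so for every batch element the embedding at each offset position is selected. The offset values below are made up purely for illustration.

import torch

batch_size, num_timesteps, embedding_dim = 2, 5, 3
mix = torch.randn(batch_size, num_timesteps, embedding_dim)

# Hypothetical last-byte-pair position of each original word
offsets = torch.tensor([[1, 3, 4],
                        [2, 4, 4]])

range_vector = torch.arange(batch_size).unsqueeze(1)     # (batch_size, 1)
last_byte_pair_embeddings = mix[range_vector, offsets]   # (batch_size, 3, embedding_dim)
assert last_byte_pair_embeddings.shape == (2, 3, 3)
assert torch.equal(last_byte_pair_embeddings[0, 1], mix[0, 3])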
Example #3
    def forward(self, input_tensor: torch.Tensor):
        """
        Adds a positional encoding to `input_tensor`.
        """
        # TODO: Another option is to specify the expected size in init, so that we can construct
        # the positional encoding beforehand, and simply add it to the input tensor in forward.
        _, timesteps, hidden_dim = input_tensor.size()
        num_timescales = hidden_dim // 2
        device = get_device_of(input_tensor)

        timestep_range = get_range_vector(timesteps, device).data.float()
        timescale_range = get_range_vector(num_timescales, device).data.float()

        log_timescale_increments = math.log(
            float(self.max_timescale) /
            float(self.min_timescale)) / float(num_timescales - 1)
        inverse_timescales = self.min_timescale * torch.exp(
            timescale_range * -log_timescale_increments)

        # Broadcasted multiplication - shape (timesteps, num_timescales)
        scaled_time = timestep_range.unsqueeze(
            1) * inverse_timescales.unsqueeze(0)
        # shape (timesteps, 2 * num_timescales)
        sinusoids = torch.cat([torch.sin(scaled_time),
                               torch.cos(scaled_time)], 1)
        if hidden_dim % 2 != 0:
            # if the number of dimensions is odd, the cos and sin
            # timescales had size (hidden_dim - 1) / 2, so we need
            # to add a row of zeros to make up the difference.
            sinusoids = torch.cat(
                [sinusoids, sinusoids.new_zeros(timesteps, 1)], 1)
        return input_tensor + sinusoids.unsqueeze(0)
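Each position t is paired with a geometric series of inverse timescales, and sin/cos of t times each inverse timescale are concatenated along the channel dimension. A standalone sketch of the same computation (the defaults of 1.0 and 1.0e4 for the timescale range are assumptions; the module above reads them from self):

import math
import torch

def sinusoidal_encoding(timesteps: int, hidden_dim: int,
                        min_timescale: float = 1.0,
                        max_timescale: float = 1.0e4) -> torch.Tensor:
    num_timescales = hidden_dim // 2
    positions = torch.arange(timesteps).float()            # (timesteps,)
    timescale_idx = torch.arange(num_timescales).float()   # (num_timescales,)
    # guard against num_timescales == 1, otherwise same increment as above
    log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inverse_timescales = min_timescale * torch.exp(-log_increment * timescale_idx)
    scaled_time = positions.unsqueeze(1) * inverse_timescales.unsqueeze(0)
    sinusoids = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
    if hidden_dim % 2 != 0:
        # pad the missing column when hidden_dim is odd, as above
        sinusoids = torch.cat([sinusoids, sinusoids.new_zeros(timesteps, 1)], dim=1)
    return sinusoids.unsqueeze(0)   # (1, timesteps, hidden_dim), broadcasts over the batch

assert sinusoidal_encoding(10, 8).shape == (1, 10, 8)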
Example #4
def masking_blockdiagonal(passage_length, window, device_id):
    """ Make a (passage_length, passage_length) tensor M of 1 and -1 in which for each row x,
        M[x, y] = -1 if y < x - window or y > x + window, else it is 1.
        Basically for the x-th row, the [x-win, x+win] columns should be 1, and rest -1
    """

    # The lower and upper limit of token-idx that won't be masked for a given token
    lower = allenutil.get_range_vector(passage_length,
                                       device=device_id) - window
    upper = allenutil.get_range_vector(passage_length,
                                       device=device_id) + window
    lower = torch.clamp(lower, min=0, max=passage_length - 1)
    upper = torch.clamp(upper, min=0, max=passage_length - 1)
    lower_un = lower.unsqueeze(1)
    upper_un = upper.unsqueeze(1)

    # Range vector for each row
    lower_range_vector = allenutil.get_range_vector(
        passage_length, device=device_id).unsqueeze(0)
    upper_range_vector = allenutil.get_range_vector(
        passage_length, device=device_id).unsqueeze(0)

    # Masks for lower and upper limits of the mask
    lower_mask = lower_range_vector >= lower_un
    upper_mask = upper_range_vector <= upper_un

    # Final-mask that we require
    inwindow_mask = (lower_mask == upper_mask).float()
    outwindow_mask = (lower_mask != upper_mask).float()

    return inwindow_mask, outwindow_mask
Example #5
    def _generate_valid_antecedents(
        num_spans_to_keep: int, max_antecedents: int, device: int
    ) -> Tuple[torch.IntTensor, torch.IntTensor, torch.BoolTensor]:
        """
        This method generates possible antecedents per span which survived the pruning
        stage. This procedure is `generic across the batch`. The reason this is the case is
        that each span in a batch can be coreferent with any previous span, but here we
        are computing the possible `indices` of these spans. So, regardless of the batch,
        the 1st span _cannot_ have any antecedents, because there are none to select from.
        Similarly, each element can only predict previous spans, so this returns a matrix
        of shape (num_spans_to_keep, max_antecedents), where the (i,j)-th index is equal to
        (i - 1) - j if j <= i, or zero otherwise.

        # Parameters

        num_spans_to_keep : `int`, required.
            The number of spans that were kept while pruning.
        max_antecedents : `int`, required.
            The maximum number of antecedent spans to consider for every span.
        device : `int`, required.
            The CUDA device to use.

        # Returns

        valid_antecedent_indices : `torch.LongTensor`
            The indices of every antecedent to consider with respect to the top k spans.
            Has shape `(num_spans_to_keep, max_antecedents)`.
        valid_antecedent_offsets : `torch.LongTensor`
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e not the word distance between the spans).
            Has shape `(1, max_antecedents)`.
        valid_antecedent_mask : `torch.BoolTensor`
            The mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            Has shape `(1, num_spans_to_keep, max_antecedents)`.
        """
        # Shape: (num_spans_to_keep, 1)
        target_indices = util.get_range_vector(num_spans_to_keep,
                                               device).unsqueeze(1)

        # Shape: (1, max_antecedents)
        valid_antecedent_offsets = (
            util.get_range_vector(max_antecedents, device) + 1).unsqueeze(0)

        # This is a broadcasted subtraction.
        # Shape: (num_spans_to_keep, max_antecedents)
        raw_antecedent_indices = target_indices - valid_antecedent_offsets

        # In our matrix of indices, the upper triangular part will be negative
        # because the offsets will be > the target indices. We want to mask these,
        # because these are exactly the indices which we don't want to predict, per span.
        # Shape: (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_mask = (raw_antecedent_indices >= 0).unsqueeze(0)

        # Shape: (num_spans_to_keep, max_antecedents)
        valid_antecedent_indices = F.relu(
            raw_antecedent_indices.float()).long()
        return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_mask
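A tiny worked example of the broadcasted subtraction, with num_spans_to_keep = 4 and max_antecedents = 3 (plain torch.arange stands in for util.get_range_vector):

import torch

num_spans_to_keep, max_antecedents = 4, 3
target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)                 # (4, 1)
valid_antecedent_offsets = (torch.arange(max_antecedents) + 1).unsqueeze(0)   # (1, 3)
raw_antecedent_indices = target_indices - valid_antecedent_offsets            # (4, 3)
print(raw_antecedent_indices)
# tensor([[-1, -2, -3],
#         [ 0, -1, -2],
#         [ 1,  0, -1],
#         [ 2,  1,  0]])

valid_antecedent_mask = (raw_antecedent_indices >= 0).unsqueeze(0)   # span 0 has no valid antecedents
valid_antecedent_indices = torch.relu(raw_antecedent_indices.float()).long()

Row i lists the candidates i-1, i-2, ...; the negative entries are exactly the invalid positions that the mask zeroes out and the relu clamps to index 0.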
Example #6
    def _construct_loss(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        head_indices: torch.Tensor,
        head_tags: torch.Tensor,
        mask: torch.Tensor,
        head_tag_temperature: Optional[float] = None,
        head_temperature: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        float_mask = mask.float()
        tag_mask = self._get_unknown_tag_mask(mask, head_tags)

        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        if head_temperature:
            attended_arcs /= head_temperature
        normalised_arc_logits = masked_log_softmax(
            attended_arcs,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        if head_tag_temperature:
            head_tag_logits /= head_tag_temperature
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits,
            tag_mask.unsqueeze(-1)) * tag_mask.float().unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        tag_loss *= (head_tags > 1).float()
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        loss = arc_nll + tag_nll
        return loss, normalised_arc_logits, normalised_head_tag_logits
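The two gather steps rely on the same advanced-indexing pattern: range_vector (batch_size, 1), child_index (batch_size, sequence_length) and head_indices (batch_size, sequence_length) broadcast together, so entry [b, i] of the result is the log-probability that word i in sentence b attaches to its gold head. A minimal sketch with made-up head indices:

import torch
import torch.nn.functional as F

batch_size, sequence_length = 2, 4
# stand-in for the masked, normalised arc log-probabilities
normalised_arc_logits = F.log_softmax(
    torch.randn(batch_size, sequence_length, sequence_length), dim=-1)

# hypothetical gold head index for every position (position 0 is the artificial ROOT)
head_indices = torch.tensor([[0, 0, 1, 2],
                             [0, 2, 0, 1]])

range_vector = torch.arange(batch_size).unsqueeze(1)                 # (batch_size, 1)
child_index = torch.arange(sequence_length).unsqueeze(0).expand(
    batch_size, sequence_length)                                     # (batch_size, sequence_length)

arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
assert arc_loss.shape == (batch_size, sequence_length)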
Example #7
    def loss(self, edge_scores: torch.Tensor, head_indices: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the edge loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        edge_scores : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = edge_scores.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(edge_scores)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(
            edge_scores,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(edge_scores))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum()
        if self.normalize_wrt_seq_len:
            arc_nll /= valid_positions.float()
        return arc_nll
Example #8
    def forward(self, inputs: torch.Tensor,
                offsets: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``, required
            A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
            for the current batch.
        offsets: ``torch.Tensor``, required
            A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
            for the current batch.

        Returns
        -------
        ``[torch.Tensor]``
            An embedding representation of the input sequence
            having shape ``(batch_size, sequence_length, embedding_dim)``
        """
        # pylint: disable=arguments-differ
        batch_size, num_timesteps = inputs.size()

        # the transformer "vocab" consists of the actual vocab and the
        # positional encodings. Here we want the count of just the former.
        vocab_size = self._transformer.vocab_size - self._transformer.n_ctx

        # vocab_size, vocab_size + 1, ...
        positional_encodings = get_range_vector(
            num_timesteps, device=get_device_of(inputs)) + vocab_size

        # Combine the inputs with positional encodings
        batch_tensor = torch.stack(
            [
                inputs,  # (batch_size, num_timesteps)
                positional_encodings.expand(batch_size, num_timesteps)
            ],
            dim=-1)

        byte_pairs_mask = inputs != 0

        # Embeddings is num_output_layers x (batch_size, num_timesteps, embedding_dim)
        layer_activations = self._transformer(batch_tensor)

        # Output of scalar_mix is (batch_size, num_timesteps, embedding_dim)
        mix = self._scalar_mix(layer_activations, byte_pairs_mask)

        # These embeddings are one per byte-pair, but we want one per original _word_.
        # So we choose the embedding corresponding to the last byte pair for each word,
        # which is captured by the ``offsets`` input.
        range_vector = get_range_vector(batch_size,
                                        device=get_device_of(mix)).unsqueeze(1)
        last_byte_pair_embeddings = mix[range_vector, offsets]

        return last_byte_pair_embeddings
Example #9
    def loss(self, edge_label_logits: torch.Tensor, mask: torch.Tensor,
             head_tags: torch.Tensor) -> torch.Tensor:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        edge_label_logits : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            that contains raw predictions for incoming edge labels
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the edge label loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = edge_label_logits.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(edge_label_logits)).unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        normalised_edge_label_logits = masked_log_softmax(
            edge_label_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(edge_label_logits))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        tag_loss = normalised_edge_label_logits[range_vector, child_index,
                                                head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        if self.normalize_wrt_seq_len:
            return -tag_loss.sum() / valid_positions.float()
        else:
            return -tag_loss.sum()
Example #10
    def _construct_loss(
            self, head_tag_representation: torch.Tensor,
            child_tag_representation: torch.Tensor,
            attended_arcs: torch.Tensor, head_indices: torch.Tensor,
            head_tags: torch.Tensor,
            mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        float_mask = mask.float()

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(
            attended_arcs,
            mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        normalised_head_tag_logits = torch.nn.functional.log_softmax(
            head_tag_logits, dim=-1) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
Example #11
    def forward(self, inputs: torch.Tensor, mask: torch.Tensor,
                span: torch.Tensor) -> torch.Tensor:
        # pylint: disable=arguments-differ

        # inputs -> [B x seq_len x d], span -> [B x 2]
        batch_size, seq_len, _ = inputs.size()

        pos_range = util.get_range_vector(seq_len,
                                          util.get_device_of(inputs)).repeat(
                                              (batch_size, 1))

        start_offset = span[:, 0].unsqueeze(dim=1)
        end_offset = span[:, 1].unsqueeze(dim=1)

        left_mask = torch.lt(pos_range, start_offset).long()
        middle_mask = (torch.ge(pos_range, start_offset) *
                       torch.le(pos_range, end_offset)).long()
        right_mask = torch.gt(pos_range, end_offset).long()

        offsets = start_offset * left_mask + end_offset * right_mask

        relative_positions = (1 + self._n_position + (pos_range - offsets) *
                              (1 - middle_mask))

        # mask padding so it won't receive a positional embedding
        relative_positions = relative_positions * mask.long()

        return self._embedding(relative_positions)
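A sketch of the relative-position computation for a single toy sequence and span, assuming self._n_position bounds the relative distance and self._embedding is an ordinary nn.Embedding (both assumptions; the sketch only reproduces the index arithmetic and omits the final padding mask):

import torch

seq_len, n_position = 8, 10
pos_range = torch.arange(seq_len).unsqueeze(0)   # (1, seq_len)
span = torch.tensor([[2, 4]])                    # hypothetical inclusive [start, end]

start_offset = span[:, 0].unsqueeze(1)
end_offset = span[:, 1].unsqueeze(1)

left_mask = (pos_range < start_offset).long()
middle_mask = ((pos_range >= start_offset) & (pos_range <= end_offset)).long()
right_mask = (pos_range > end_offset).long()

offsets = start_offset * left_mask + end_offset * right_mask
relative_positions = 1 + n_position + (pos_range - offsets) * (1 - middle_mask)
print(relative_positions)
# tensor([[ 9, 10, 11, 11, 11, 12, 13, 14]])

Positions inside the span all map to the same index (1 + n_position); positions outside it are shifted by their signed distance to the nearest span boundary.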
Example #12
    def set_input(self, encoded_input: torch.Tensor,
                  mask: torch.Tensor) -> None:
        self.encoded_input = encoded_input
        batch_size = encoded_input.shape[0]
        self.mask = mask
        self.batch_size_range = get_range_vector(batch_size,
                                                 get_device_of(encoded_input))
Example #13
    def forward(self, word_inputs: torch.Tensor, char_inputs: torch.Tensor):
        embs = []
        if self.word_embedder is not None:
            word_inputs = torch.autograd.Variable(word_inputs,
                                                  requires_grad=False)
            embed_words = self.word_embedder(word_inputs)
            embs.append(embed_words)

        if self.char_embedder is not None:
            char_inputs, char_lengths = char_inputs
            batch_size, seq_len = char_lengths.size()[:2]
            char_inputs = char_inputs.view(batch_size * seq_len, -1)
            char_lengths = char_lengths.view(batch_size * seq_len, -1)

            # (batch_size * seq_len, max_char, dim)
            embeded_chars = self.char_embedder(char_inputs)
            _, max_seq_len, dim = embeded_chars.size()

            layer = embeded_chars
            for length in range(1, max_seq_len):
                new_layer = layer.new_zeros(layer.size())
                range_vector = get_range_vector(max_seq_len,
                                                get_device_of(char_lengths))
                mask = ((range_vector.unsqueeze(0) - char_lengths + length) <=
                        0).unsqueeze(-1)
                for i in range(max_seq_len - length):
                    new_layer[:, i, :] = self.cell(layer[:, i:i + 2, :])
                layer.masked_scatter_(mask, new_layer)
            embs.append(layer[:, 0, :].view(batch_size, seq_len, dim))

        token_embedding = torch.cat(embs, dim=2)

        return self.projection(token_embedding)
Example #14
    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(
            max_batch_span_width,
            util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(
            raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor,
                                                    span_indices,
                                                    flat_span_indices)

        #  text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        batch_size, num_spans, max_batch_span_width, _ = span_embeddings.size()

        view_text_embeddings = span_embeddings.view(batch_size * num_spans,
                                                    max_batch_span_width, -1)
        span_mask = span_mask.view(batch_size * num_spans,
                                   max_batch_span_width)
        cnn_text_embeddings = self.cnn(view_text_embeddings, span_mask)
        cnn_text_embeddings = cnn_text_embeddings.view(batch_size, num_spans,
                                                       self._output_dim)
        return cnn_text_embeddings
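A small sketch of how the span indices and mask come together, for a toy batch with two inclusive spans (torch.arange stands in for util.get_range_vector; the CNN at the end is left out):

import torch

span_indices = torch.tensor([[[1, 3], [0, 1]]])           # (batch_size=1, num_spans=2, 2)
span_starts, span_ends = span_indices.split(1, dim=-1)     # each (1, 2, 1)
span_widths = span_ends - span_starts
max_batch_span_width = span_widths.max().item() + 1        # 3

max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)

span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
span_mask = span_mask * (raw_span_indices >= 0).float()
span_indices = torch.relu(raw_span_indices.float()).long()

print(span_indices)   # tensor([[[3, 2, 1], [1, 0, 0]]])
print(span_mask)      # tensor([[[1., 1., 1.], [1., 1., 0.]]])

Each span is read backwards from its end; the shorter span's trailing slot is clamped to index 0 and masked, so it never contributes.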
Example #15
    def _get_head_tags(self, head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(
            batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[
            range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits
Example #16
def get_last_sent_gate(all_predictions, num_spans, get_all_beam, eos_idx=0):
    """ all_predictions: shape: (batch_size, K, num_decoding_steps)
    """
    batch_size = all_predictions.size(0)
    num_steps = all_predictions.size(-1)
    # shape: (batch_size, K)
    last_poses = torch.sum((all_predictions != eos_idx).float(), dim=-1) - 1
    # shape: (num_decoding_steps, )
    indices = get_range_vector(num_steps, get_device_of(all_predictions)).float()
    # shape: (batch_size, K, num_decoding_steps)
    mask = (indices.view(*([1]*(all_predictions.dim()-1)), num_steps) == last_poses.unsqueeze(-1)).float()
    # shape: (batch_size, K, num_decoding_steps)
    last_predictions = all_predictions.float() * mask
    print("last_predictions:", last_predictions)

    # build the last sent gate. The dim is set to 1 + num_spans to account for the end embedding
    # shape: (batch_size, 1+num_spans) or (batch_size, K, 1+num_spans)
    if not get_all_beam:
        gate = last_predictions.new_zeros((batch_size, 1+num_spans))
    else:
        beam = all_predictions.size(1)
        gate = last_predictions.new_zeros((batch_size, beam, 1+num_spans))
    gate.scatter_(-1, last_predictions.long(), 1.)
    # remove the column for end embedding
    # shape: (batch_size, num_spans) or (batch_size, K, num_spans)
    gate = gate[..., 1:]

    # shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1)
    if not get_all_beam:
        gate = gate.reshape(batch_size * num_spans, 1)
    else:
        gate = gate.reshape(batch_size * beam * num_spans, 1)
    return gate
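Only the scatter/slice step is specific to the gate construction. A simplified single-beam, single-step illustration of it (the index values below are hypothetical; column 0 plays the role of the end-of-sequence embedding that gets dropped):

import torch

batch_size, num_spans = 2, 4
# hypothetical last predicted span index per example (1-based; 0 is reserved for EOS)
last_prediction = torch.tensor([[3.], [1.]])

gate = last_prediction.new_zeros((batch_size, 1 + num_spans))
gate.scatter_(-1, last_prediction.long(), 1.)   # one-hot over the 1 + num_spans columns
gate = gate[:, 1:]                              # drop the EOS column

print(gate)
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.]])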
Example #17
def get_timing_signal_1d(length,
                         channels,
                         device,
                         min_timescale=1.0,
                         max_timescale=1.0e4,
                         start_index=0):
    """Gets a bunch of sinusoids of different frequencies.
	  Each channel of the input Tensor is incremented by a sinusoid of a different
	  frequency and phase.
	  This allows attention to learn to use absolute and relative positions.
	  Timing signals should be added to some precursors of both the query and the
	  memory inputs to attention.
	  The use of relative position is possible because sin(x+y) and cos(x+y) can be
	  expressed in terms of y, sin(x) and cos(x).
	  In particular, we use a geometric sequence of timescales starting with
	  min_timescale and ending with max_timescale.  The number of different
	  timescales is equal to channels / 2. For each timescale, we
	  generate the two sinusoidal signals sin(timestep/timescale) and
	  cos(timestep/timescale).  All of these sinusoids are concatenated in
	  the channels dimension.
	  Args:
	    length: scalar, length of timing signal sequence.
	    channels: scalar, size of timing embeddings to create. The number of
	        different timescales is equal to channels / 2.
	    min_timescale: a float
	    max_timescale: a float
	    start_index: index of first position
	  Returns:
	    a Tensor of timing signals [1, length, channels]
	"""
    position = util.get_range_vector(length, device) + start_index
    position = position.float()
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        max(num_timescales - 1.0, 1.0))
    inv_timescales = min_timescale * torch.exp(
        util.get_range_vector(num_timescales, device).float() *
        -log_timescale_increment)
    scaled_time = torch.unsqueeze(position, 1) * torch.unsqueeze(
        inv_timescales, 0)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
    pad = nn.ConstantPad1d((0, channels % 2), 0)
    signal = pad(signal)
    signal = signal.view(1, length, channels)
    return signal
Example #18
    def test_openai_transformer_matches_tensorflow(self):
        model_path = "https://allennlp.s3.amazonaws.com/models/openai-transformer-lm-2018.07.23.tar.gz"
        indexer = OpenaiTransformerBytePairIndexer(model_path=model_path)
        transformer = OpenaiTransformer(model_path=model_path)

        # get the test sentences
        with open(self.FIXTURES_ROOT / 'openai_transformer' / 'text.txt',
                  'r') as fin:
            sentences = fin.read().strip().split('\n')

        # tokenize and check that indices are correct
        nlp = spacy.load('en_core_web_sm')

        # make a batch of two sentences
        batch_indices = []
        batch_lengths = []
        for k, sentence in enumerate(sentences):
            tokens = [
                token.text for token in nlp(text_standardize(sentence))
                if not token.is_space
            ]
            indices = indexer.tokens_to_indices(
                [Token(token) for token in tokens], Vocabulary(),
                'openai_indexer')
            batch_indices.append(indices['openai_indexer'])
            batch_lengths.append(
                len([i for i in indices['openai_indexer'] if i != 0]))
        batch_indices = torch.from_numpy(numpy.array(batch_indices))
        batch_size, num_timesteps = batch_indices.size()
        vocab_size = transformer.vocab_size - transformer.n_ctx
        positional_encodings = get_range_vector(num_timesteps,
                                                device=-1) + vocab_size

        # Combine the inputs with positional encodings
        batch_tensor = torch.stack(
            [
                batch_indices,  # (batch_size, num_timesteps)
                positional_encodings.expand(batch_size, num_timesteps)
            ],
            dim=-1)

        # run the LM
        transformer.eval()
        activations = transformer(batch_tensor)

        # load the expected activations
        expected_activations = []
        with h5py.File(
                self.FIXTURES_ROOT / 'openai_transformer' /
                'expected_embeddings.hdf5', 'r') as fin:
            expected_activations.append(fin['0'][...])
            expected_activations.append(fin['1'][...])

        # just check the top layer
        for k in range(2):
            actual = activations[-1][k, :batch_lengths[k], :].numpy()
            expected = expected_activations[k]
            numpy.testing.assert_almost_equal(expected, actual, decimal=5)
Example #19
def get_select_embedding(sub_words_embedding, offsets):
    # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
    offsets2d = util.combine_initial_dims(offsets)
    # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
    range_vector = util.get_range_vector(offsets2d.size(0),
                                         device=util.get_device_of(sub_words_embedding)).unsqueeze(1)
    # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
    selected_embeddings = sub_words_embedding[range_vector, offsets2d]

    return util.uncombine_initial_dims(selected_embeddings, offsets.size())
Example #20
def attention_bias_proximal(length, device=-1):
    """Bias for self-attention to encourage attention to close positions.
    Args:
        length: an integer scalar.
    Returns:
        a Tensor with shape [1, 1, length, length]
    """
    r = util.get_range_vector(length, device).float()
    diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
    return torch.unsqueeze(torch.unsqueeze(-torch.log(1 + torch.abs(diff)), 0),
                           0)
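Since the bias is -log(1 + |i - j|), attending to the same position costs nothing and the penalty grows logarithmically with distance. A quick check for length 4 with plain torch.arange (torch.log1p(x) is equivalent to torch.log(1 + x)):

import torch

length = 4
r = torch.arange(length).float()
diff = r.unsqueeze(0) - r.unsqueeze(1)        # (length, length), diff[i, j] = j - i
bias = -torch.log1p(diff.abs())
print(bias.unsqueeze(0).unsqueeze(0).shape)   # torch.Size([1, 1, 4, 4])
print(bias[0])   # first row: [0, -log 2, -log 3, -log 4] ~= [0.0, -0.69, -1.10, -1.39]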
Example #21
File: coref.py Project: belindal/allennlp
    def _coarse_to_fine_pruning(
        self, top_span_embeddings: torch.FloatTensor,
        top_span_mention_scores: torch.FloatTensor, num_spans_to_keep: int,
        max_antecedents: int, device: int
    ) -> Tuple[torch.IntTensor, torch.IntTensor, torch.FloatTensor,
               torch.FloatTensor]:
        # Shape: (num_spans_to_keep)
        target_indices = util.get_range_vector(num_spans_to_keep, device)

        # Shape: (num_spans_to_keep, num_spans_to_keep)
        valid_antecedent_offsets = target_indices.unsqueeze(
            1) - target_indices.unsqueeze(0)

        # Shape: (num_spans_to_keep, num_spans_to_keep)
        valid_antecedent_log_mask = (valid_antecedent_offsets >=
                                     1).float().unsqueeze(0).log()

        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep)
        fast_antecedent_scores = top_span_mention_scores + top_span_mention_scores.squeeze(
            -1).unsqueeze(1)
        fast_antecedent_scores += valid_antecedent_log_mask

        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep)
        coarse_scores = self._compute_coarse_scores(top_span_embeddings)
        fast_antecedent_scores += coarse_scores

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        _, top_antecedent_indices = fast_antecedent_scores.topk(
            max_antecedents, -1)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_indices, _ = torch.sort(top_antecedent_indices, dim=-1)

        # Shape: (batch_size, num_items_to_keep, max_antecedents)
        # (batch_size, num_spans_to_keep, max_antecedents)
        valid_antecedent_log_mask = valid_antecedent_log_mask.expand(
            top_antecedent_indices.size(0), -1, -1)
        top_antecedent_log_mask = torch.gather(valid_antecedent_log_mask, -1,
                                               top_antecedent_indices)

        # Shape: (batch_size, num_items_to_keep, max_antecedents)
        valid_antecedent_offsets = \
            valid_antecedent_offsets.unsqueeze(0).expand(top_antecedent_indices.size(0), -1, -1)
        top_antecedent_offsets = torch.gather(valid_antecedent_offsets, -1,
                                              top_antecedent_indices)

        # Shape: (batch_size, num_items_to_keep, max_antecedents)
        top_fast_antecedent_scores = torch.gather(fast_antecedent_scores, -1,
                                                  top_antecedent_indices)
        return top_antecedent_indices, top_antecedent_offsets, top_antecedent_log_mask, top_fast_antecedent_scores
Example #22
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch."""
    sequence_shape = sequence_tensor.size()
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]
    flat_offsets = util.get_range_vector(
        batch_size, util.get_device_of(sequence_tensor)) * seq_length
    flat_offsets = flat_offsets.unsqueeze(-1).long()
    flat_positions = (positions + flat_offsets).view(-1)
    flat_sequence_tensor = sequence_tensor.view(batch_size * seq_length, width)
    output_tensor = torch.index_select(flat_sequence_tensor, 0, flat_positions)
    return output_tensor
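A minimal sketch of the flattening trick: positions in row b are shifted by b * seq_length so a single index_select over the flattened tensor picks the right vectors (the positions below are made up):

import torch

batch_size, seq_length, width = 2, 4, 3
sequence_tensor = torch.arange(batch_size * seq_length * width).float().view(
    batch_size, seq_length, width)

positions = torch.tensor([[0, 2],
                          [1, 3]])          # hypothetical positions per row

flat_offsets = (torch.arange(batch_size) * seq_length).unsqueeze(-1)   # [[0], [4]]
flat_positions = (positions + flat_offsets).view(-1)                   # [0, 2, 5, 7]
flat_sequence_tensor = sequence_tensor.view(batch_size * seq_length, width)
output_tensor = torch.index_select(flat_sequence_tensor, 0, flat_positions)

assert torch.equal(output_tensor[0], sequence_tensor[0, 0])
assert torch.equal(output_tensor[3], sequence_tensor[1, 3])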
Example #23
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        all_encoder_layers, _ = self.bert_model(input_ids, input_mask,
                                                token_type_ids)
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]

        if offsets is None:
            return mix
        else:
            batch_size = input_ids.size(0)
            range_vector = util.get_range_vector(
                batch_size, device=util.get_device_of(mix)).unsqueeze(1)
            return mix[range_vector, offsets]
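Using the "Definitely not" example from the docstring: the five wordpieces produce a mix of shape (1, 5, embedding_dim) and the last-wordpiece offsets are [3, 4], so the indexing at the end returns one embedding per original token. A sketch with random placeholder embeddings (768 is just BERT-base's hidden size, assumed here for concreteness):

import torch

mix = torch.randn(1, 5, 768)                 # (batch_size, num_wordpieces, embedding_dim)
offsets = torch.tensor([[3, 4]])             # last wordpiece of "Definitely" and "not"

range_vector = torch.arange(1).unsqueeze(1)  # (batch_size, 1)
per_token = mix[range_vector, offsets]       # (batch_size, 2, embedding_dim)
assert torch.equal(per_token[0, 0], mix[0, 3])
assert torch.equal(per_token[0, 1], mix[0, 4])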
Example #24
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids)
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]


        if offsets is None:
            return mix
        else:
            batch_size = input_ids.size(0)
            range_vector = util.get_range_vector(batch_size,
                                                 device=util.get_device_of(mix)).unsqueeze(1)
            return mix[range_vector, offsets]
Example #25
def get_evd_prediction_mask(all_predictions, eos_idx):
    # get the mask w.r.t to ``all_predictions`` that includes the index of the first eos and those before it
    # shape(all_predictions): (batch_size, ..., num_steps)
    # Shape: (batch_size,)
    batch_size = all_predictions.size(0)
    num_steps = all_predictions.size(-1)
    # shape: (batch_size, ...)
    valid_decoding_lens = torch.sum(
        (all_predictions != eos_idx).float(), dim=-1) + 1
    indices = get_range_vector(num_steps,
                               get_device_of(all_predictions)).float()
    mask = (indices.view(*([1] * (all_predictions.dim() - 1)), num_steps) <
            valid_decoding_lens.unsqueeze(-1)).int()
    eos_mask = (all_predictions == eos_idx).int() * mask
    return mask, eos_mask
Example #26
    def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits
Example #27
File: DM.py Project: jgontrum/am-parser
    def label_scores(self, encoded_text: torch.Tensor,
                     head_indices: torch.Tensor) -> torch.Tensor:
        """
        Computes edge label scores for a fixed tree structure (given by head_indices) for a batch of sentences.

        Parameters
        ----------
        encoded_text: (batch_size, sequence_length, encoder_output_dim)

        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word (predicted or gold).

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text)
        )  # will be used to generate predictions for the edge labels for the given arcs.
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text)
        )  # will be used to generate predictions for the edge labels for the given arcs.

        batch_size = head_tag_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(
            batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[
            range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits
Example #28
    def label_scores(self, encoded_text: torch.Tensor,
                     head_indices: torch.Tensor) -> torch.Tensor:
        """
        Computes edge label scores for a fixed tree structure (given by head_indices) for a batch of sentences.

        Parameters
        ----------
         encoded_text : torch.Tensor, required
            The input sentence, with artificial root node (head sentinel) added at the beginning, of
            shape (batch_size, sequence length, encoding dim)
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word (predicted or gold).

        Returns
        -------
        edge_label_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each given arc.
        """
        # shape (batch_size, sequence_length, tag_representation_dim)
        head_label_representation = self.head_label_feedforward(encoded_text)
        child_label_representation = self.child_label_feedforward(encoded_text)

        batch_size = head_label_representation.size(0)
        # shape (batch_size,)
        range_vector = get_range_vector(
            batch_size, get_device_of(head_label_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_label_representations = head_label_representation[
            range_vector, head_indices]
        selected_head_label_representations = selected_head_label_representations.contiguous()

        combined = self.activation(selected_head_label_representations +
                                   child_label_representation)
        #(batch_size, sequence_length, num_head_tags)
        edge_label_logits = self.label_out_layer(combined)

        return edge_label_logits
Example #29
    def test_openai_transformer_matches_tensorflow(self):
        model_path = "https://s3-us-west-2.amazonaws.com/allennlp/models/openai-transformer-lm-2018.07.23.tar.gz"
        indexer = OpenaiTransformerBytePairIndexer(model_path=model_path)
        transformer = OpenaiTransformer(model_path=model_path)

        # get the test sentences
        with open(self.FIXTURES_ROOT / 'openai_transformer' / 'text.txt', 'r') as fin:
            sentences = fin.read().strip().split('\n')

        # tokenize and check that indices are correct
        nlp = spacy.load('en_core_web_sm')

        # make a batch of two sentences
        batch_indices = []
        batch_lengths = []
        for k, sentence in enumerate(sentences):
            tokens = [token.text for token in nlp(text_standardize(sentence)) if not token.is_space]
            indices = indexer.tokens_to_indices(
                    [Token(token) for token in tokens], Vocabulary(), 'openai_indexer'
            )
            batch_indices.append(indices['openai_indexer'])
            batch_lengths.append(len([i for i in indices['openai_indexer'] if i != 0]))
        batch_indices = torch.from_numpy(numpy.array(batch_indices))
        batch_size, num_timesteps = batch_indices.size()
        vocab_size = transformer.vocab_size - transformer.n_ctx
        positional_encodings = get_range_vector(num_timesteps, device=-1) + vocab_size

        # Combine the inputs with positional encodings
        batch_tensor = torch.stack([
                batch_indices,   # (batch_size, num_timesteps)
                positional_encodings.expand(batch_size, num_timesteps)
        ], dim=-1)

        # run the LM
        transformer.eval()
        activations = transformer(batch_tensor)

        # load the expected activations
        expected_activations = []
        with h5py.File(self.FIXTURES_ROOT / 'openai_transformer' / 'expected_embeddings.hdf5', 'r') as fin:
            expected_activations.append(fin['0'][...])
            expected_activations.append(fin['1'][...])

        # just check the top layer
        for k in range(2):
            actual = activations[-1][k, :batch_lengths[k], :].numpy()
            expected = expected_activations[k]
            numpy.testing.assert_almost_equal(expected, actual, decimal=5)
Example #30
    def get_input_type_ids(self, type_ids, offsets, embedder):
        "Converts (bsz, seq_len_wp) to (bsz, seq_len_wp) by indexing."
        batch_size = type_ids.size(0)
        full_seq_len = type_ids.size(1)
        if full_seq_len > embedder.max_pieces:  # Recombine if we had used sliding window approach
            assert batch_size == 1 and type_ids.max() > 0
            num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero(
            ).size(0)
            select_indices = embedder.indices_to_select(
                full_seq_len, num_question_tokens)
            type_ids = type_ids[:, select_indices]

        range_vector = util.get_range_vector(
            batch_size, device=util.get_device_of(type_ids)).unsqueeze(1)
        type_ids = type_ids[range_vector, offsets]
        return type_ids
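# --- Illustrative sketch (toy values; torch.arange stands in for util.get_range_vector) ---
# offsets holds, for each original word token, the wordpiece position whose type id
# should be kept, so the gather below maps wordpiece-level ids to token-level ids.
import torch

type_ids = torch.tensor([[0, 0, 0, 1, 1, 1]])     # (bsz=1, seq_len_wp=6)
offsets = torch.tensor([[0, 2, 4, 5]])            # one wordpiece index per word token
range_vector = torch.arange(type_ids.size(0)).unsqueeze(1)
token_type_ids = type_ids[range_vector, offsets]  # (1, 4) -> [[0, 0, 1, 1]]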
Example #31
def number2count_auxloss(passage_number_values: List[List[float]],
                         device_id=-1):
    """ Using passage numnbers, make a (batch_size, max_passage_numbers) (padded) tensor, each containing a
        noisy distribution with mass distributed over x-numbers. The corresponding count-answer will be x.
        Use the attention2count rnn to predict a count value and compute the loss.
    """
    batch_size = len(passage_number_values)
    # List of length -- batch-size
    num_of_passage_numbers = [len(nums) for nums in passage_number_values]
    max_passage_numbers = max(num_of_passage_numbers)

    # Shape: (batch_size, )
    num_passage_numbers = util.move_to_device(
        torch.LongTensor(num_of_passage_numbers), cuda_device=device_id)
    # Shape: (max_passage_numbers, )
    range_vector = util.get_range_vector(size=max_passage_numbers,
                                         device=device_id)

    # Shape: (batch_size, max_passage_numbers)
    mask = (range_vector.unsqueeze(0) <
            num_passage_numbers.unsqueeze(1)).float()

    number_distributions = mask.new_zeros(batch_size,
                                          max_passage_numbers).normal_(
                                              0, 0.01).abs_()
    count_answers = number_distributions.new_zeros(batch_size,
                                                   max_passage_numbers).long()

    for i, num_numbers in enumerate(num_of_passage_numbers):
        """ Sample a count value between [0, min(5, num_numbers)]. Sample indices in this range, and set them as 1.
            Add gaussian noise to the whole tensor and normalize. 
        """
        # Pick a count answer
        count_value = random.randint(0, min(7, num_numbers))
        count_answers[i, count_value] = 1
        # Pick the indices that will have mass
        if count_value > 0:
            indices = random.sample(range(num_numbers), count_value)
            # Add 1.0 to all sampled indices
            number_distributions[i, indices] += 1.0

    number_distributions = number_distributions * mask
    number_distributions = number_distributions / torch.sum(
        number_distributions, dim=1).unsqueeze(1)
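# --- Illustrative sketch of the padding mask built above (plain torch, toy values) ---
# Broadcasting a (1, max_len) range vector against (batch, 1) lengths gives a
# (batch, max_len) mask that is 1.0 for real passage numbers and 0.0 for padding.
import torch

num_passage_numbers = torch.tensor([3, 1, 4])     # numbers found in each passage
max_passage_numbers = int(num_passage_numbers.max())
range_vector = torch.arange(max_passage_numbers)  # (max_passage_numbers,)
mask = (range_vector.unsqueeze(0) < num_passage_numbers.unsqueeze(1)).float()
# tensor([[1., 1., 1., 0.],
#         [1., 0., 0., 0.],
#         [1., 1., 1., 1.]])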
Example #32
    def common_step(self, batch, phase="train"):
        (token_ids, type_ids, offsets, wordpiece_mask, pos_tags, word_mask,
         mrc_mask, meta_data, parent_idxs,
         parent_tags) = (batch["token_ids"], batch["type_ids"],
                         batch["offsets"], batch["wordpiece_mask"],
                         batch["pos_tags"], batch["word_mask"],
                         batch["mrc_mask"], batch["meta_data"],
                         batch["parent_idxs"], batch["parent_tags"])
        parent_probs, parent_tag_probs, parent_arc_nll, parent_tag_nll = self(
            token_ids, type_ids, offsets, wordpiece_mask, pos_tags, word_mask,
            mrc_mask, parent_idxs, parent_tags)
        loss = parent_arc_nll + parent_tag_nll
        eval_mask = self._get_mask_for_eval(mask=word_mask, pos_tags=pos_tags)
        bsz = parent_probs.size(0)
        # [bsz]
        batch_range_vector = get_range_vector(bsz, get_device_of(parent_tags))
        eval_mask = eval_mask[batch_range_vector, parent_idxs]  # [bsz]
        if phase == "train" or not self.args.use_mst:
            # [bsz]
            pred_positions = parent_probs.argmax(1)
            metric_name = f"{phase}_stat"
            metric = getattr(self, metric_name)
            metric.update(
                pred_positions.unsqueeze(-1),  # [bsz, 1]
                parent_tag_probs[batch_range_vector, pred_positions].argmax(
                    1).unsqueeze(-1),  # [bsz, 1]
                parent_idxs.unsqueeze(-1),  # [bsz, 1]
                parent_tags.unsqueeze(-1),  # [bsz, 1]
                eval_mask.unsqueeze(-1)  # [bsz, 1]
            )
        else:  # todo implement mst decoding
            metric = getattr(self, f"{phase}_stat")
            metric.update(meta_data["ann_idx"], meta_data["word_idx"],
                          [len(x) for x in meta_data["words"]], parent_probs,
                          parent_tag_probs, eval_mask)

        # acc_metric = getattr(self, f"{phase}_acc")
        # acc_metric.update(
        #     preds=is_subtree_probs,
        #     target=is_subtree
        # )

        self.log(f'{phase}_loss', loss)
        return loss
def flatten_and_batch_shift_indices(indices: torch.Tensor,
                                    sequence_length: int) -> torch.Tensor:
    """
    This is a subroutine for :func:`~batched_index_select`. The given ``indices`` of size
    ``(batch_size, d_1, ..., d_n)`` indexes into dimension 2 of a target tensor, which has size
    ``(batch_size, sequence_length, embedding_size)``. This function returns a vector that
    correctly indexes into the flattened target. The sequence length of the target must be
    provided to compute the appropriate offsets.

    .. code-block:: python

        indices = torch.ones([2,3], dtype=torch.long)
        # Sequence length of the target tensor.
        sequence_length = 10
        shifted_indices = flatten_and_batch_shift_indices(indices, sequence_length)
        # Indices into the second element in the batch are correctly shifted
        # to take into account that the target tensor will be flattened before
        # the indices are applied.
        assert shifted_indices == [1, 1, 1, 11, 11, 11]

    Parameters
    ----------
    indices : ``torch.LongTensor``, required.
    sequence_length : ``int``, required.
        The length of the sequence the indices index into.
        This must be the second dimension of the tensor.

    Returns
    -------
    offset_indices : ``torch.LongTensor``
    """
    # Shape: (batch_size)
    offsets = get_range_vector(indices.size(0),
                               get_device_of(indices)) * sequence_length
    for _ in range(len(indices.size()) - 1):
        offsets = offsets.unsqueeze(1)

    # Shape: (batch_size, d_1, ..., d_n)
    offset_indices = indices + offsets
    # Shape: (batch_size * d_1 * ... * d_n)
    offset_indices = offset_indices.view(-1)
    return offset_indices
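# --- Illustrative sketch of how the shifted indices are consumed (plain torch) ---
# After flattening the (batch, seq, dim) target to (batch * seq, dim), the offset
# indices select the right rows for every batch element in a single gather.
import torch

batch_size, sequence_length, dim = 2, 10, 4
target = torch.randn(batch_size, sequence_length, dim)
indices = torch.ones([2, 3], dtype=torch.long)              # as in the docstring above
offsets = torch.arange(batch_size) * sequence_length        # [0, 10]
offset_indices = (indices + offsets.unsqueeze(1)).view(-1)  # [1, 1, 1, 11, 11, 11]
flat_target = target.view(batch_size * sequence_length, dim)
selected = flat_target[offset_indices].view(batch_size, 3, dim)
assert torch.equal(selected[1, 0], target[1, 1])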
Example #34
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.Tensor,
                span: torch.Tensor) -> torch.Tensor:
        # pylint: disable=arguments-differ,unused-argument

        # input -> [B x seq_len x d], offset -> [B x 2]
        batch_size, seq_len, _ = inputs.size()

        offset = span[:, 0].unsqueeze(-1)
        position_range = util.get_range_vector(
                seq_len, util.get_device_of(inputs)).repeat((batch_size, 1))

        offset_mask = position_range == offset

        position_markers = inputs.new_ones((batch_size, seq_len), requires_grad=True)
        position_markers = position_markers * offset_mask.float()
        position_markers = position_markers.unsqueeze(-1)

        return position_markers
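# --- Illustrative sketch (toy values, plain torch) ---
# The comparison against the span start yields a one-hot marker per sequence:
# batch 0 gets a 1.0 at position 1, batch 1 at position 4.
import torch

batch_size, seq_len = 2, 5
span = torch.tensor([[1, 3], [4, 4]])                    # (batch, 2): inclusive start/end
offset = span[:, 0].unsqueeze(-1)                        # (batch, 1)
position_range = torch.arange(seq_len).repeat(batch_size, 1)
position_markers = (position_range == offset).float().unsqueeze(-1)  # (batch, seq_len, 1)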
    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic ROOT token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
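# --- Illustrative sketch of the gold-arc selection above (plain torch, toy values) ---
# log_probs[range_vector, child_index, head_indices] reads, for batch element b and
# child position c, the log-probability assigned to that child's gold head.
import torch

batch_size, sequence_length = 2, 4
log_probs = torch.randn(batch_size, sequence_length, sequence_length)
head_indices = torch.tensor([[0, 0, 1, 1],
                             [0, 2, 0, 3]])
range_vector = torch.arange(batch_size).unsqueeze(1)                # (batch, 1)
child_index = torch.arange(sequence_length).expand(batch_size, -1)  # (batch, seq)
arc_loss = log_probs[range_vector, child_index, head_indices]       # (batch, seq)
assert torch.equal(arc_loss[1, 2], log_probs[1, 2, 0])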
Example #36
    def _generate_valid_antecedents(num_spans_to_keep: int,
                                    max_antecedents: int,
                                    device: int) -> Tuple[torch.IntTensor,
                                                          torch.IntTensor,
                                                          torch.FloatTensor]:
        """
        This method generates possible antecedents per span which survived the pruning
        stage. This procedure is `generic across the batch`. The reason this is the case is
        that each span in a batch can be coreferent with any previous span, but here we
        are computing the possible `indices` of these spans. So, regardless of the batch,
        the 1st span _cannot_ have any antecedents, because there are none to select from.
        Similarly, each element can only predict previous spans, so this returns a matrix
        of shape (num_spans_to_keep, max_antecedents), where the (i,j)-th index is equal to
        (i - 1) - j if j < i, and zero otherwise.

        Parameters
        ----------
        num_spans_to_keep : ``int``, required.
            The number of spans that were kept while pruning.
        max_antecedents : ``int``, required.
            The maximum number of antecedent spans to consider for every span.
        device: ``int``, required.
            The CUDA device to use.

        Returns
        -------
        valid_antecedent_indices : ``torch.IntTensor``
            The indices of every antecedent to consider with respect to the top k spans.
            Has shape ``(num_spans_to_keep, max_antecedents)``.
        valid_antecedent_offsets : ``torch.IntTensor``
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e. not the word distance between the spans).
            Has shape ``(1, max_antecedents)``.
        valid_antecedent_log_mask : ``torch.FloatTensor``
            The logged mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            Has shape ``(1, num_spans_to_keep, max_antecedents)``.
        """
        # Shape: (num_spans_to_keep, 1)
        target_indices = util.get_range_vector(num_spans_to_keep, device).unsqueeze(1)

        # Shape: (1, max_antecedents)
        valid_antecedent_offsets = (util.get_range_vector(max_antecedents, device) + 1).unsqueeze(0)

        # This is a broadcasted subtraction.
        # Shape: (num_spans_to_keep, max_antecedents)
        raw_antecedent_indices = target_indices - valid_antecedent_offsets

        # In our matrix of indices, the upper triangular part will be negative
        # because the offsets will be > the target indices. We want to mask these,
        # because these are exactly the indices which we don't want to predict, per span.
        # We're generating a logspace mask here because we will eventually create a
        # distribution over these indices, so we need the 0 elements of the mask to be -inf
        # in order to not mess up the normalisation of the distribution.
        # Shape: (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_log_mask = (raw_antecedent_indices >= 0).float().unsqueeze(0).log()

        # Shape: (num_spans_to_keep, max_antecedents)
        valid_antecedent_indices = F.relu(raw_antecedent_indices.float()).long()
        return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask
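# --- Illustrative sketch with num_spans_to_keep=4 and max_antecedents=3 (plain torch) ---
# The broadcasted subtraction makes row i list the indices of the spans preceding
# span i; negative entries (no valid antecedent) are masked and clamped to 0.
import torch
import torch.nn.functional as F

target_indices = torch.arange(4).unsqueeze(1)          # (4, 1)
offsets = (torch.arange(3) + 1).unsqueeze(0)           # (1, 3): offsets 1, 2, 3
raw_antecedent_indices = target_indices - offsets
# tensor([[-1, -2, -3],    span 0 has no antecedents
#         [ 0, -1, -2],    span 1 may refer back to span 0
#         [ 1,  0, -1],
#         [ 2,  1,  0]])
log_mask = (raw_antecedent_indices >= 0).float().unsqueeze(0).log()  # -inf where invalid
antecedent_indices = F.relu(raw_antecedent_indices.float()).long()   # negatives -> 0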
Example #37
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, _ = self.bert_model(input_ids=util.combine_initial_dims(input_ids),
                                                token_type_ids=util.combine_initial_dims(token_type_ids),
                                                attention_mask=util.combine_initial_dims(input_mask))
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            return util.uncombine_initial_dims(mix, input_ids.size())
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(offsets2d.size(0),
                                                 device=util.get_device_of(mix)).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = mix[range_vector, offsets2d]

            return util.uncombine_initial_dims(selected_embeddings, offsets.size())
Example #38
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens belonging to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following.
        Each entry is a nested list: the outer list iterates over the dialogs, the inner list over the questions in each dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over the N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = util.batched_index_select(global_attention_logits,
                                                          span_indices,
                                                          flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

        return attended_text_embeddings
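# --- Illustrative sketch of the span masking above (toy values, plain torch) ---
# Positions are enumerated backwards from each (inclusive) span end; entries that
# fall before the span start, or below index 0, are masked and clamped.
import torch

span_indices = torch.tensor([[[1, 3], [0, 1]]])        # (batch=1, num_spans=2, 2)
span_starts, span_ends = span_indices.split(1, dim=-1)
span_widths = span_ends - span_starts                  # off by 1: ends are inclusive
max_batch_span_width = span_widths.max().item() + 1
max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
span_mask = span_mask * (raw_span_indices >= 0).float()
clamped_indices = torch.nn.functional.relu(raw_span_indices.float()).long()
# clamped_indices[0, 0] == [3, 2, 1]; clamped_indices[0, 1] == [1, 0, 0] (last masked)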