def forward(self, inputs: torch.Tensor, offsets: torch.Tensor = None) -> torch.Tensor:
        """
        Embed a batch of byte-pair ids with the wrapped transformer and, when
        ``offsets`` are given, reduce the result to one vector per original word.

        Parameters
        ----------
        inputs: ``torch.Tensor``, required
            A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
            for the current batch.
        offsets: ``torch.Tensor``, optional (default = None)
            A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
            for the current batch. When omitted, the embeddings of all byte pairs
            are returned instead of one embedding per word.

        Returns
        -------
        ``torch.Tensor``
            An embedding representation of the input sequence
            having shape ``(batch_size, sequence_length, embedding_dim)``
        """
        # pylint: disable=arguments-differ
        batch_size, num_timesteps = inputs.size()

        # the transformer embedding consists of the byte pair embeddings,
        # the special embeddings and the position embeddings.
        # the position embeddings are always at least self._transformer.n_ctx,
        # but may be longer.
        # the transformer "vocab" consists of the actual vocab and the
        # positional encodings. Here we want the count of just the former.
        vocab_size = self._transformer.vocab_size - self._transformer.n_ctx

        # vocab_size, vocab_size + 1, ...
        positional_encodings = get_range_vector(num_timesteps, device=get_device_of(inputs)) + vocab_size

        # Combine the inputs with positional encodings: (batch_size, num_timesteps, 2)
        batch_tensor = torch.stack([
                inputs,   # (batch_size, num_timesteps)
                positional_encodings.expand(batch_size, num_timesteps)
        ], dim=-1)

        # Id 0 is treated as padding.
        byte_pairs_mask = inputs != 0

        # Embeddings is num_output_layers x (batch_size, num_timesteps, embedding_dim)
        layer_activations = self._transformer(batch_tensor)

        # Output of scalar_mix is (batch_size, num_timesteps, embedding_dim)
        if self._top_layer_only:
            mix = layer_activations[-1]
        else:
            mix = self._scalar_mix(layer_activations, byte_pairs_mask)

        # These embeddings are one per byte-pair, but we want one per original _word_.
        # So we choose the embedding corresponding to the last byte pair for each word,
        # which is captured by the ``offsets`` input.
        if offsets is not None:
            # (batch_size, 1) row index, broadcast against ``offsets``.
            range_vector = get_range_vector(batch_size, device=get_device_of(mix)).unsqueeze(1)
            last_byte_pair_embeddings = mix[range_vector, offsets]
        else:
            # allow to return all byte pairs by passing no offsets;
            # trim to the longest non-padded sequence in the batch.
            seq_len = (byte_pairs_mask > 0).long().sum(dim=1).max()
            last_byte_pair_embeddings = mix[:, :seq_len]

        return last_byte_pair_embeddings
    def create_cached_cnn_embeddings(self, tokens: List[str]) -> None:
        """
        Given a list of tokens, this method precomputes word representations
        by running just the character convolutions and highway layers of elmo,
        essentially creating uncontextual word vectors. On subsequent forward passes,
        the word ids are looked up from an embedding, rather than being computed on
        the fly via the CNN encoder.

        This function sets 3 attributes:

        _word_embedding : ``torch.Tensor``
            The word embedding for each word in the tokens passed to this method.
        _bos_embedding : ``torch.Tensor``
            The embedding for the BOS token.
        _eos_embedding : ``torch.Tensor``
            The embedding for the EOS token.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            A list of tokens to precompute character convolutions for.
        """
        # BOS and EOS are prepended so they occupy rows 0 and 1 of the result.
        tokens = [ELMoCharacterMapper.bos_token, ELMoCharacterMapper.eos_token] + tokens
        timesteps = 32
        batch_size = 32
        # Tokens are grouped into pseudo-sentences of ``timesteps`` tokens each,
        # which are then grouped into batches of ``batch_size`` below.
        chunked_tokens = lazy_groups_of(iter(tokens), timesteps)

        all_embeddings = []
        device = get_device_of(next(self.parameters()))
        for batch in lazy_groups_of(chunked_tokens, batch_size):
            # Shape (batch_size, timesteps, 50)
            batched_tensor = batch_to_ids(batch)
            # NOTE: This device check is for when a user calls this method having
            # already placed the model on a device. If this is called in the
            # constructor, it will probably happen on the CPU. This isn't too bad,
            # because it's only a few convolutions and will likely be very fast.
            if device >= 0:
                batched_tensor = batched_tensor.cuda(device)
            output = self._token_embedder(batched_tensor)
            token_embedding = output["token_embedding"]
            mask = output["mask"]
            token_embedding, _ = remove_sentence_boundaries(token_embedding, mask)
            # Flatten (batch, timesteps, dim) into one row per token.
            all_embeddings.append(token_embedding.view(-1, token_embedding.size(-1)))
        # NOTE(review): the next statement is truncated in the visible source.
        # Presumably the per-batch tensors in ``all_embeddings`` are concatenated
        # along dim 0 here — restore the call from the original file.
        full_embedding =, 0)

        # We might have some trailing embeddings from padding in the batch, so
        # we clip the embedding and lookup to the right size.
        full_embedding = full_embedding[:len(tokens), :]
        # Rows 0 and 1 are BOS/EOS; the remaining rows are the caller's tokens.
        embedding = full_embedding[2:len(tokens), :]
        vocab_size, embedding_dim = list(embedding.size())

        from allennlp.modules.token_embedders import Embedding # type: ignore
        self._bos_embedding = full_embedding[0, :]
        self._eos_embedding = full_embedding[1, :]
        # NOTE(review): the ``Embedding(...)`` call below is cut off in the visible
        # source; its remaining arguments must be restored from the original file.
        self._word_embedding = Embedding(vocab_size, # type: ignore
    def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}

        # Maps each non-terminal to the (rule, index-into-possible_actions) pairs it can expand to.
        actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            # NOTE(review): the bodies/branches of the following ``if`` statements are
            # incomplete in the visible source (the empty-rule branch has no body and
            # the ``raise`` has presumably lost its ``else:``) — restore before use.
            if action.rule == "":
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append((action, i))
                raise ValueError("The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append((production_rule_array.rule_id, action_index))

            if global_actions:
                # Unzip into parallel tuples: rule-id tensors and their action indices.
                global_action_tensors, global_action_ids = zip(*global_actions)
                # NOTE(review): truncated statement — presumably the rule-id tensors are
                # concatenated along dim 0 before ``.long()``; confirm against the original.
                global_action_tensor =, dim=0).long()
                if device >= 0:
                    # NOTE(review): right-hand side missing in the visible source
                    # (presumably moves the tensor to ``device``).
                    global_action_tensor =

                global_input_embeddings = self._action_embedder(global_action_tensor)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)

                # NOTE(review): both statements below are cut off in the visible source.
                translated_valid_actions[key]['global'] = (global_input_embeddings,
        return GrammarStatelet(['statement'],
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Run the wrapped BERT model and return either all wordpiece embeddings or
        only those selected by ``offsets``.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.

        Returns
        -------
        ``torch.Tensor``
            The mixed (or top-layer) embeddings, indexed by ``offsets`` when given.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Wordpiece id 0 is treated as padding.
        input_mask = (input_ids != 0).long()

        # NOTE(review): the arguments are passed positionally as
        # (ids, mask, type_ids); verify this order against the wrapped model's
        # ``forward`` signature — some BERT implementations expect
        # ``token_type_ids`` before ``attention_mask``.
        all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids)
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            # No scalar mix configured: use the top layer only.
            mix = all_encoder_layers[-1]

        if offsets is None:
            return mix

        batch_size = input_ids.size(0)
        # (batch_size, 1) row index that broadcasts against ``offsets`` so each
        # row of ``offsets`` selects positions from its own batch element.
        range_vector = util.get_range_vector(batch_size,
                                             device=util.get_device_of(mix)).unsqueeze(1)
        return mix[range_vector, offsets]
    def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # Integer-array ("advanced") indexing: the (batch_size, 1) row index
        # broadcasts against ``head_indices``, so for every word we pick the
        # representation of its head from the same batch element.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits
    def _get_prediction_device(self) -> int:
        This method checks the device of the model parameters to determine the cuda_device
        this model should be run on for predictions.  If there are no parameters, it returns -1.

        The cuda device this model should run on for predictions.
        devices = {util.get_device_of(param) for param in self.parameters()}

        if len(devices) > 1:
            devices_string = ", ".join(str(x) for x in devices)
            raise ConfigurationError(f"Parameters have mismatching cuda_devices: {devices_string}")
        elif len(devices) == 1:
            return devices.pop()
            return -1
def flatten_and_batch_shift_indices(indices: torch.Tensor,
                                    sequence_length: int) -> torch.Tensor:
    """
    This is a subroutine for :func:`~batched_index_select`. The given ``indices`` of size
    ``(batch_size, d_1, ..., d_n)`` indexes into dimension 2 of a target tensor, which has size
    ``(batch_size, sequence_length, embedding_size)``. This function returns a vector that
    correctly indexes into the flattened target. The sequence length of the target must be
    provided to compute the appropriate offsets.

    .. code-block:: python

        indices = torch.ones([2,3], dtype=torch.long)
        # Sequence length of the target tensor.
        sequence_length = 10
        shifted_indices = flatten_and_batch_shift_indices(indices, sequence_length)
        # Indices into the second element in the batch are correctly shifted
        # to take into account that the target tensor will be flattened before
        # the indices are applied.
        assert shifted_indices == [1, 1, 1, 11, 11, 11]

    Parameters
    ----------
    indices : ``torch.LongTensor``, required.
    sequence_length : ``int``, required.
        The length of the sequence the indices index into.
        This must be the second dimension of the tensor.

    Returns
    -------
    offset_indices : ``torch.LongTensor``
        A 1-D tensor with ``batch_size * d_1 * ... * d_n`` entries.
    """
    # Shape: (batch_size); entry b is the start of batch element b in the
    # flattened target, i.e. b * sequence_length.
    offsets = get_range_vector(indices.size(0),
                               get_device_of(indices)) * sequence_length
    # Append trailing singleton dimensions so the offsets broadcast against ``indices``.
    for _ in range(indices.dim() - 1):
        offsets = offsets.unsqueeze(1)

    # Shape: (batch_size, d_1, ..., d_n)
    offset_indices = indices + offsets
    # Shape: (batch_size * d_1 * ... * d_n)
    return offset_indices.view(-1)
    def _get_prediction_device(self) -> int:
        This method checks the device of the model parameters to determine the cuda_device
        this model should be run on for predictions.  If there are no parameters, it returns -1.

        The cuda device this model should run on for predictions.
        devices = {util.get_device_of(param) for param in self.parameters()}

        if len(devices) > 1:
            devices_string = ", ".join(str(x) for x in devices)
            raise ConfigurationError(
                f"Parameters have mismatching cuda_devices: {devices_string}")
        elif len(devices) == 1:
            return devices.pop()
            return -1
    def label_scores(self, encoded_text: torch.Tensor, head_indices: torch.Tensor) -> torch.Tensor:
        """
        Computes edge label scores for a fixed tree structure (given by head_indices) for a batch of sentences.

        Parameters
        ----------
        encoded_text : torch.Tensor, required
            The input sentence, with artificial root node (head sentinel) added in the beginning of
            shape (batch_size, sequence length, encoding dim)
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word (predicted or gold).

        Returns
        -------
        edge_label_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each given arc.
        """
        # shape (batch_size, sequence_length, tag_representation_dim)
        head_label_representation = self._dropout(self.head_label_feedforward(encoded_text))
        child_label_representation = self._dropout(self.child_label_feedforward(encoded_text))

        batch_size = head_label_representation.size(0)
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(head_label_representation)).unsqueeze(1)

        # Integer-array ("advanced") indexing: the (batch_size, 1) row index
        # broadcasts against ``head_indices``, so for every word we pick the
        # representation of its head from the same batch element.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_label_representations = head_label_representation[range_vector, head_indices]
        selected_head_label_representations = selected_head_label_representations.contiguous()

        # Head and child representations are combined additively before the output layer.
        combined = self.activation(selected_head_label_representations + child_label_representation)
        # (batch_size, sequence_length, num_head_tags)
        edge_label_logits = self.label_out_layer(combined)

        return edge_label_logits
def get_next_sentence_output(use_fp16, input_tensor, next_sentence_feedforward,
                             labels):
    """Get loss and log probs for the next sentence prediction.

    Parameters
    ----------
    use_fp16 : bool
        When true, the one-hot label matrix is built in half precision.
    input_tensor : torch.Tensor
        Input passed to ``next_sentence_feedforward``.
    next_sentence_feedforward : callable
        Maps ``input_tensor`` to two-class logits (last dimension of size 2).
    labels : torch.LongTensor
        Class index (0 or 1) per example; reshaped to ``(num_examples, 1)``.

    Returns
    -------
    (loss, per_example_loss, log_probs)
        ``loss`` is the sum of ``per_example_loss``; ``log_probs`` is the
        log-softmax of the logits.
    """
    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.
    logits = next_sentence_feedforward(input_tensor)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    labels = labels.view(-1, 1)
    # Allocate the one-hot matrix zero-filled on the labels' device. Building
    # it with the bare ``FloatTensor(n, 2)`` constructor leaves the cells that
    # ``scatter_`` does not touch uninitialised, which corrupts the loss.
    one_hot_labels = torch.zeros(labels.size(0), 2,
                                 dtype=torch.half if use_fp16 else torch.float,
                                 device=labels.device)
    one_hot_labels.scatter_(1, labels, 1)
    per_example_loss = -(one_hot_labels * log_probs).sum(-1)
    loss = torch.sum(per_example_loss)
    return (loss, per_example_loss, log_probs)
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.Tensor,
                span: torch.Tensor) -> torch.Tensor:
        """
        Build a marker tensor that is 1 at the position given by the first
        column of ``span`` and 0 everywhere else.

        ``inputs`` is unpacked as ``(batch, seq_len, dim)`` and only supplies the
        output's size, dtype and device; ``mask`` is not used. The result has
        shape ``(batch, seq_len, 1)``.
        """
        # pylint: disable=arguments-differ,unused-argument

        num_rows, num_positions, _ = inputs.size()

        # (batch, 1): the marked position for each batch element.
        span_starts = span[:, 0].unsqueeze(-1)

        # (batch, seq_len): 0 .. seq_len - 1 repeated for every row.
        positions = util.get_range_vector(num_positions, util.get_device_of(inputs))
        positions = positions.repeat((num_rows, 1))

        is_span_start = (positions == span_starts).float()

        ones = inputs.new_ones((num_rows, num_positions), requires_grad=True)
        return (ones * is_span_start).unsqueeze(-1)
# Pools ``tokens`` once per mask collected in ``pooling_masks`` (one mask per
# requested scope).
# NOTE(review): several statements in this function are truncated in the visible
# source (marked below); the code cannot run until they are restored.
def scoped_pool(tokens: torch.Tensor,
                mask: torch.Tensor,
                pooling: str,
                pooling_scopes: List[PoolingScope],
                is_bidirectional: bool = False,
                head: torch.Tensor = None,
                tail: torch.Tensor = None) -> torch.Tensor:
    pooling_masks = []

    # NOTE(review): the body of this branch is missing — presumably it registers
    # the full-sequence ``mask`` in ``pooling_masks``; confirm against the original.
    if PoolingScope.SEQUENCE in pooling_scopes:

    if PoolingScope.HEAD in pooling_scopes or PoolingScope.TAIL in pooling_scopes:
        assert head is not None and tail is not None, \
            "head and tail offsets are required for pooling on entities"

        batch_size, seq_len, _ = tokens.size()
        # (batch_size, seq_len): positions 0 .. seq_len - 1 repeated for every row.
        pos_range = util.get_range_vector(
                seq_len, util.get_device_of(tokens)).repeat((batch_size, 1))

        if PoolingScope.HEAD in pooling_scopes:
            # Columns 0 and 1 of ``head`` are read as start and (inclusive, given
            # the ``le`` comparison) end positions.
            head_start = head[:, 0].unsqueeze(dim=1)
            head_end = head[:, 1].unsqueeze(dim=1)
            # NOTE(review): the first comparison is truncated — presumably a
            # ``>= head_start`` test on ``pos_range``. The resulting mask is also
            # never added to ``pooling_masks`` in the visible source.
            head_mask = ((, head_start)
                          * torch.le(pos_range, head_end)).unsqueeze(-1).long())

        if PoolingScope.TAIL in pooling_scopes:
            tail_start = tail[:, 0].unsqueeze(dim=1)
            tail_end = tail[:, 1].unsqueeze(dim=1)

            # NOTE(review): truncated in the same way as ``head_mask`` above.
            tail_mask = ((, tail_start)
                          * torch.le(pos_range, tail_end)).unsqueeze(-1).long())

    assert pooling_masks, "At least one pooling scope must be defined"

    # The comprehension variable ``mask`` shadows the ``mask`` parameter here.
    pooled = [pool(tokens, mask, dim=1, pooling=pooling,
                   is_bidirectional=is_bidirectional) for mask in pooling_masks]

    # NOTE(review): truncated — presumably concatenates ``pooled`` along the last dim.
    return, dim=-1)
    def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Score the dependency tag of every (head, child) arc selected by ``head_indices``.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            Indexed as (batch_size, sequence_length, tag_representation_dim).
        child_tag_representation : ``torch.Tensor``, required.
            Second operand of ``self.tag_bilinear``, paired position-wise with
            the selected head representations.
        head_indices : ``torch.Tensor``, required.
            (batch_size, sequence_length) index of the head of every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            Output of ``self.tag_bilinear``;
            (batch_size, sequence_length, num_head_tags).
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # Integer-array ("advanced") indexing: the (batch_size, 1) row index
        # broadcasts against ``head_indices``, so for every word we pick the
        # representation of its head from the same batch element.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous()
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits
def _convert_actionscores_to_probs(batch_actionseq_scores: List[List[torch.Tensor]]) -> List[torch.FloatTensor]:
    """ Normalize program scores in a beam for an instance to get probabilities

        For each instance, a tensor the size of number of predicted programs
        containing normalized probabilities
    # Convert batch_action_scores to a single tensor the size of number of programs for each instance
    device_id = allenutil.get_device_of(batch_actionseq_scores[0][0])
    # Inside List[torch.Tensor] is a list of scalar-tensor with prob of each program for this instance
    # The prob is normalized across the programs in the beam
    batch_actionseq_probs = []
    for score_list in batch_actionseq_scores:
        scores_astensor = allenutil.move_to_device([x.view(1) for x in score_list]), device_id)
        # allenutil.masked_softmax(scores_astensor, mask=None)
        action_probs = torch.nn.functional.softmax(scores_astensor, dim=-1)

    return batch_actionseq_probs
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.Tensor,
                span: torch.Tensor) -> torch.Tensor:
        """
        Embed each position by its offset from the position in the first column of ``span``.

        ``inputs`` is unpacked as ``(batch, seq_len, dim)`` and only supplies the
        sizes and device. Offsets are shifted by ``1 + self._n_position`` and
        multiplied by ``mask`` (so masked positions become index 0) before
        being looked up in ``self._embedding``.
        """
        # pylint: disable=arguments-differ

        num_rows, num_positions, _ = inputs.size()

        # (batch, 1): reference position for each batch element.
        span_starts = span[:, 0].unsqueeze(dim=1)

        # (batch, seq_len): 0 .. seq_len - 1 repeated for every row.
        positions = util.get_range_vector(num_positions, util.get_device_of(inputs))
        positions = positions.repeat((num_rows, 1))

        shifted = 1 + self._n_position + positions - span_starts

        # mask padding so it won't receive a positional embedding
        shifted = shifted * mask.long()

        return self._embedding(shifted)
    def common_step(self, batch, phase="train"):
        """
        Run the model on one batch, update the ``{phase}_stat`` metric and log the loss.

        Parameters
        ----------
        batch : dict
            Must provide the keys ``token_ids``, ``type_ids``, ``offsets``,
            ``wordpiece_mask``, ``span_idx``, ``span_tag``, ``pos_tags``,
            ``word_mask``, ``mrc_mask``, ``meta_data``, ``child_arcs`` and ``child_tags``.
        phase : str
            Prefix of the metric attribute (``{phase}_stat``) and of the logged
            loss name (``{phase}_loss``).

        Returns
        -------
        The sum of the parent arc and parent tag negative log-likelihoods.
        """
        token_ids, type_ids, offsets, wordpiece_mask, span_idx, span_tag, pos_tags, word_mask, mrc_mask, meta_data, child_arcs, child_tags = (
            batch["token_ids"], batch["type_ids"], batch["offsets"],
            batch["wordpiece_mask"], batch["span_idx"], batch["span_tag"],
            batch["pos_tags"], batch["word_mask"], batch["mrc_mask"],
            batch["meta_data"], batch["child_arcs"], batch["child_tags"])
        parent_probs, parent_tag_probs, child_probs, child_tag_probs, parent_arc_nll, parent_tag_nll, child_arc_loss, child_tag_loss = self(
            token_ids, type_ids, offsets, wordpiece_mask, span_idx, span_tag,
            child_arcs, child_tags, pos_tags, word_mask, mrc_mask)
        # todo fix child bug
        loss = parent_arc_nll + parent_tag_nll  # + child_arc_loss + child_tag_loss
        eval_mask = self._get_mask_for_eval(mask=word_mask, pos_tags=pos_tags)
        bsz = span_idx.size(0)
        # [bsz]
        batch_range_vector = get_range_vector(bsz, get_device_of(span_idx))
        gold_positions = span_idx[:, 0]
        # Keep only the evaluation flag at each example's gold position.
        eval_mask = eval_mask[batch_range_vector, gold_positions]  # [bsz]

        # Both branches update the same per-phase metric object.
        metric = getattr(self, f"{phase}_stat")
        if phase == "train" or not self.args.use_mst:
            # Greedy decoding: take the arg-max position and its arg-max tag.
            # [bsz]
            pred_positions = parent_probs.argmax(1)
            metric.update(
                pred_positions.unsqueeze(-1),  # [bsz, 1]
                parent_tag_probs[batch_range_vector, pred_positions].argmax(
                    1).unsqueeze(-1),  # [bsz, 1]
                gold_positions.unsqueeze(-1),  # [bsz, 1]
                span_tag.unsqueeze(-1),  # [bsz, 1]
                eval_mask.unsqueeze(-1)  # [bsz, 1]
            )
        else:
            # Hand the full probability tensors to the metric together with the
            # metadata it needs to regroup examples.
            metric.update(meta_data["ann_idx"], meta_data["word_idx"],
                          [len(x) for x in meta_data["words"]], parent_probs,
                          parent_tag_probs, child_probs, child_tag_probs,
                          span_idx, span_tag, eval_mask)

        self.log(f'{phase}_loss', loss)
        return loss
    def concat_features(self, emb_z, token_indices, span_len):
        """
        concatenate two features
            emb_z (batch_size, sentence_len, featsdim) : contextualized word representations
            token_indices: Dict[str, LongTensor], indices of different fields
            span_len: a number (from 0)

        NOTE(review): most statements in this method are truncated in the visible
        source (see the notes below); it cannot run until they are restored.
        """
        batch_size = emb_z.size(0)
        sent_len = emb_z.size(1)
        hidden_dim = emb_z.size(2)
        # Insert a singleton dimension at position 1: (batch_size, 1, sent_len, hidden_dim).
        emb_z = emb_z.unsqueeze(1).expand(batch_size, 1, sent_len, hidden_dim)

        # One window of ``span_len + 1`` consecutive tokens per start position i.
        # NOTE(review): the closing bracket of this list and the right-hand side of
        # the following assignment are missing (presumably a concatenation along dim 1).
        span_exprs = [
            emb_z[:, :, i:i + span_len + 1] for i in range(sent_len - span_len)
        span_exprs =, 1)

        # Difference between the first and last token of each window.
        # NOTE(review): the ``expand`` size list is cut off.
        endpoint_vec = (span_exprs[:, :, 0] -
                        span_exprs[:, :, span_len]).unsqueeze(2).expand(
                            batch_size, sent_len - span_len, span_len + 1,

        # Indices 0 .. span_len, looked up in ``self.index_embeds``.
        # NOTE(review): the ``.cuda(`` argument and the ``expand`` size list are cut off.
        index = Variable(torch.LongTensor(range(span_len + 1))).cuda(
        index = self.index_embeds(index).unsqueeze(0).unsqueeze(0).expand(
            batch_size, sent_len - span_len, span_len + 1,

        # NOTE(review): remaining arguments of this call are cut off.
        BILOU_features = self.get_BILOU_features(token_indices, sent_len,

        # NOTE(review): truncated — presumably the window, BILOU, endpoint and index
        # features are concatenated here; the first operand and the dim are missing.
        new_emb =, BILOU_features, endpoint_vec, index),

        return new_emb.transpose(1, 2).contiguous()
def generate_embeddings_for_pooling(sequence_tensor, span_starts, span_ends):
    """Gather the token embeddings of every span, padded to the widest span.

    Parameters
    ----------
    sequence_tensor : (B, L, E) tensor of token embeddings.
    span_starts : (B, L) tensor of span start indices.
    span_ends : (B, L) tensor of span end indices. One is subtracted before
        use, after which the end is treated as inclusive.

    Returns
    -------
    span_embeddings : (batch_size, num_spans, max_batch_span_width, embedding_dim)
        Embeddings selected for each span position. Positions are enumerated
        backwards from the span end (``span_end - 0``, ``span_end - 1``, ...).
    span_mask : (batch_size, num_spans, max_batch_span_width) float tensor
        1.0 where a position belongs to the span and is a valid (>= 0)
        sequence index, 0.0 elsewhere.
    """
    #(B, L, E), #(B, L), #(B, L)
    span_starts = span_starts.unsqueeze(-1)
    span_ends = (span_ends - 1).unsqueeze(-1)
    span_widths = span_ends - span_starts
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(
        max_batch_span_width,
        util.get_device_of(sequence_tensor)).view(1, 1, -1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    # We're using <= here (and for the mask below) because the span ends are
    # inclusive, so we want to include indices which are equal to span_widths rather
    # than using it as a non-inclusive upper bound.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the sequence
    # have an end index < max_batch_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    # Clamp negative indices to 0 so the gather below stays in range; those
    # positions are already zeroed in ``span_mask``.
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(
        span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices,
                                                flat_span_indices)

    return span_embeddings, span_mask
 # NOTE(review): only the header of `device_id` is present at this point in the
 # file -- its body (and any decorator) is missing, and its one-space indent
 # does not match the surrounding definitions. Restore from version control.
 def device_id(self):
    def _create_grammar_state(
            self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.

        Returns
        -------
        A ``GrammarStatelet`` whose stack starts at ``'statement'`` and whose valid actions
        map each non-terminal to ``{'global': (input_embeddings, output_embeddings, action_ids)}``.

        Raises
        ------
        ValueError
            If any non-empty action is not a global rule.
        """
        device = util.get_device_of(self._action_embedder.weight)
        # TODO(Mark): This type is pure \(- . ^)/
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}

        actions_grouped_by_nonterminal: Dict[str, List[Tuple[
            ProductionRule, int]]] = defaultdict(list)
        for i, action in enumerate(possible_actions):
            # Actions with an empty rule string carry nothing to embed.
            if action.rule == "":
                continue
            if action.is_global_rule:
                actions_grouped_by_nonterminal[action.nonterminal].append(
                    (action, i))
            else:
                raise ValueError(
                    "The sql parser doesn't support non-global actions yet.")

        for key, production_rule_arrays in actions_grouped_by_nonterminal.items(
        ):
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `production_rule_arrays` are all
            # the valid productions of that non-terminal, paired with their index in
            # `possible_actions`.
            global_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                global_actions.append(
                    (production_rule_array.rule_id, action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors,
                                                 dim=0).long()
                # ``get_device_of`` reports a negative id for CPU tensors.
                if device >= 0:
                    global_action_tensor = global_action_tensor.to(device)

                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)

                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))
        # NOTE(review): the trailing arguments were lost in this copy of the file and are
        # restored following the upstream AllenNLP text2sql parser -- confirm that
        # ``self.is_nonterminal`` exists on this class.
        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # NOTE(review): the text from here down to "A scalar loss to be optimised."
        # reads as this method's docstring, but no triple-quote delimiters or
        # section headings are present in this copy, and a few of its sentences
        # stop mid-way. As written it is not valid Python -- restore from
        # version control before relying on this method.
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        # NOTE(review): the next statement stops after `question[`; the key and
        # the rest of the expression are missing from this copy.
        batch_size, max_qa_count, max_q_len, _ = question[
        # The (batch_size, max_qa_count) dimensions are flattened into a single
        # leading dimension of size total_qa_count for the rest of the method.
        total_qa_count = batch_size * max_qa_count
        # NOTE(review): the right-hand side below has lost its leading call.
        qa_mask =, 0).view(total_qa_count)
        # NOTE(review): the embedder / reshape / dropout calls below are left
        # open -- their trailing arguments are missing.
        embedded_question = self._text_field_embedder(question,
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
        passage_length = embedded_passage.size(1)

        # NOTE(review): the first get_text_field_mask call is left open.
        question_mask = util.get_text_field_mask(question,
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        # Repeat the passage mask once per question, then flatten to
        # (total_qa_count, passage_length).
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            # NOTE(review): the call below is left open, and the assignment after
            # it has lost the call that consumed the `[...]` list and `dim=-1`.
            question_num_marker_emb = self._question_num_marker(
            embedded_question =
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            # NOTE(review): the three `repeated_embedded_passage =` assignments
            # below each lack the call head for their list argument, and a few
            # neighbouring calls are left open.
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage =
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage =
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                    repeated_embedded_passage =
                        [repeated_embedded_passage, p3_answer_marker_emb],

            # NOTE(review): the first assignment below is left open and no
            # `else:` is visible before `encoded_passage`; as written, the later
            # assignments to `repeated_encoded_passage` would overwrite this one.
            # Presumably the following three statements belong to an `else`
            # branch of `self._num_context_answers > 0` -- confirm.
            repeated_encoded_passage = self._variational_dropout(
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        # NOTE(review): the `.max(` call below is left open.
        question_passage_similarity = masked_similarity.max(
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        # NOTE(review): the list below has lost both its call head and its
        # closing bracket / dim argument.
        final_merged_passage =[
            repeated_encoded_passage, passage_question_vectors,
            repeated_encoded_passage * passage_question_vectors,
            repeated_encoded_passage * tiled_question_passage_vector

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        # NOTE(review): both statements below are left open.
        residual_layer = self._variational_dropout(
        self_attention_matrix = self._self_attention(residual_layer,

        # Pairwise passage mask with the diagonal zeroed by `(1 - self_mask)`,
        # so a token gets no attention weight on itself.
        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        # NOTE(review): the torch.eye call below is left open.
        self_mask = torch.eye(passage_length,
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        # NOTE(review): the next four statements are all truncated (open call,
        # headless list, open F.relu).
        self_attention_vecs = torch.matmul(self_attention_probs,
        self_attention_vecs =[
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        residual_layer = F.relu(

        # Residual connection around the self-attention sub-layer.
        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        # NOTE(review): the span-start encoder call is left open.
        start_rep = self._span_start_encoder(final_merged_passage,
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        # NOTE(review): the argument line below has lost its call head and its
        # indentation; the encoder call is never closed.
        end_rep = self._span_end_encoder(
  [final_merged_passage, start_rep], dim=-1),
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        # The yes/no and follow-up heads both read the span-end representation.
        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(

        # NOTE(review): the two replace_masked_values calls and the
        # _get_best_span_yesno_followup call below are left open.
        span_start_logits = util.replace_masked_values(span_start_logits,
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits,

        best_span = self._get_best_span_yesno_followup(span_start_logits,

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            # NOTE(review): both nll_loss calls are left open (mask, target and
            # any ignore_index arguments are missing).
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            # The index arithmetic below (`end * 3 + i * passage_length * 3 + k`,
            # k = 0..2) addresses three consecutive entries of the flattened
            # yes/no and follow-up logits per question; `max(..., 0)` clamps
            # negative indices to 0.
            # NOTE(review): the statements that consume these `max(...)`
            # expressions, the tail of `span_end.view(`, and the right-hand
            # sides of `gold_span_end_loc =` / `predicted_end =` are missing.
            gold_span_end_loc = []
            span_end = span_end.view(
            for i in range(0, total_qa_count):
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc =

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end =

            # Three-way logits selected at the gold span end feed the loss.
            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            # NOTE(review): both nll_loss calls below are left open.
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
            loss += nll_loss(F.log_softmax(_followup, dim=-1),

            # Same selection at the predicted span end.
            # NOTE(review): nothing visible here consumes these two values.
            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        # NOTE(review): `metadata` defaults to None in the signature but is
        # indexed unconditionally here.
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            # NOTE(review): the `enumerate(` iterable and the index expression
            # on the following line are truncated.
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                predicted_span = tuple(best_span_cpu[i * max_qa_count +

                # Token offsets are used as character offsets into the passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]

                best_span_string = passage_str[start_offset:end_offset]
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        # NOTE(review): as written `refs` contains every
                        # reference (nothing is removed from `idxes`), the
                        # statement that appended to `t_f1` is missing, and no
                        # `else:` is visible before the second `f1_score =`.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            refs = [answer_texts[z] for z in idxes]
                                    squad_eval.f1_score, best_span_string,
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                # F1 is recorded on a 0-100 scale.
                # NOTE(review): nothing visible appends to the per-dialog lists
                # or to the output_dict lists initialised above.
                self._official_f1(100 * f1_score)
        return output_dict
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Embed wordpiece ids with the wrapped BERT model.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.

        Returns
        -------
        The embeddings, reshaped back to the leading dimensions of ``input_ids``
        (when ``offsets`` is None) or of ``offsets`` (otherwise), with a trailing
        embedding dimension.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Wordpiece id 0 is treated as padding.
        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        # NOTE(review): the call arguments were lost in this copy of the file and
        # are restored following the upstream AllenNLP embedder -- confirm the
        # keyword names against the wrapped model's ``forward``.
        all_encoder_layers, _ = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            attention_mask=util.combine_initial_dims(input_mask))
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            return util.uncombine_initial_dims(mix, input_ids.size())

        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(
            offsets2d.size(0), device=util.get_device_of(mix)).unsqueeze(1)
        # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
        selected_embeddings = mix[range_vector, offsets2d]

        return util.uncombine_initial_dims(selected_embeddings,
                                           offsets.size())
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marking previous previous answer in passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        An output dictionary consisting of the followings.
        Each of the followings is a nested list because first iterates over dialog, then questions in dialog.

        qid : List[List[str]]
            A list of list, consisting of question ids.
        followup : List[List[int]]
            A list of list, consisting of continuation marker prediction index.
            (y :yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of list, consisting of affirmation marker prediction index.
            (y :yes, x: not a yes/no question, n: np)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask =, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question =[embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage =[repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage =[repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage =[repeated_embedded_passage, p3_answer_marker_emb],

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage =[repeated_encoded_passage,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs =[self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder([final_merged_passage, start_rep], dim=-1),
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc =

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end =

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]

                best_span_string = passage_str[start_offset:end_offset]
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            refs = [answer_texts[z] for z in idxes]
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                self._official_f1(100 * f1_score)
        return output_dict
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachements of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        # NOTE(review): the ``def`` line and the docstring delimiters of this method are
        # not visible in this part of the file -- confirm they are present.
        # Float copy of the mask, used below to zero out scores at masked positions.
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        # Masked log-softmax of the arc scores; the two broadcast multiplications then
        # zero every entry whose row position or column position is masked, so those
        # entries add nothing to the summed loss.
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        # Same treatment for the tag scores: log-softmax, then zero the masked positions.
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # shape (batch_size, sequence_length)
        # Advanced indexing: ``range_vector`` broadcasts against ``child_index``, so for
        # every (batch, position) pair we pick the log-probability stored at the gold
        # head index (arcs) and at the gold tag (tags).
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        # Negate the summed log-probabilities and average over the valid positions;
        # the count is cast to float before dividing.
        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Embed the candidate spans, prune them by mention score, and predict an
        antecedent for every span that survives pruning. When gold cluster ids
        are supplied, the coreference loss is computed and the metrics updated.

        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            One dictionary per instance in the batch. It is handed to the mention-recall
            and CoNLL metrics when ``span_labels`` is given, and its ``"original_text"``
            entries are copied into the output under ``"document"``.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only present when ``span_labels`` is given.
        document : ``List``, optional
            The ``"original_text"`` entry of every metadata dictionary. Only present
            when ``metadata`` is given.
        """
        # NOTE(review): the ``torch.cat`` call and the trailing arguments of the pruner,
        # index-select, span-pair and scoring calls below were restored from truncated
        # source lines -- confirm them against the helper signatures.

        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        # NOTE(review): ``squeeze(-1)`` is a no-op unless num_spans == 1, in which case
        # it drops the span dimension -- confirm that is intended.
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0, which the relu below guarantees. This is only
        # relevant in edge cases where the number of spans we consider after the
        # pruning stage is >= the total number of spans, because in this case,
        # it is possible we might consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            # Adding log(label) leaves the log-probability of gold-consistent antecedents
            # untouched (log 1 = 0) and sends every other entry to -inf.
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_ner_labels: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:  # add matrix from datareader
        # pylint: disable=arguments-differ
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        doc_ner_labels : ``torch.IntTensor``.
            A tensor of shape # TODO,
        doc_span_offsets : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
        doc_truth_spans : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, 1),
        doc_spans_in_truth : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
        doc_relation_labels : ``torch.Tensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans),

        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(

        batch_size = len(spans)
        document_length = text_embeddings.size(1)
        max_sentence_length = max(
            len(sentence) for document in metadata
            for sentence in document['doc_tokens'])
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # TODO features dropout
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings =
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_relex_spans_to_keep = int(
            math.floor(self._relex_spans_per_word * max_sentence_length))

        # Shapes:
        # (batch_size, num_spans_to_keep, span_dim),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # Shape: (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = dict()

        output_dict["top_spans"] = top_spans
        output_dict["antecedent_indices"] = valid_antecedent_indices
        output_dict["predicted_antecedents"] = predicted_antecedents

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]

        # Shape: (,)
        loss = 0

        # Shape: (batch_size, max_sentences, max_spans)
        doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
        # Shape: (batch_size, max_sentences, num_spans, span_dim)
        doc_span_embeddings = util.batched_index_select(

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        pruned = self._relex_mention_pruner(
        (top_relex_span_embeddings, top_relex_span_mask,
         top_relex_span_indices, top_relex_span_mention_scores) = pruned

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

        # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)  # TODO do we need for a mask?
        doc_spans = util.batched_index_select(

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
        top_relex_spans = nd_batched_index_select(doc_spans,

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
         relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
             top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
        relex_scores = self._compute_relex_scores(
            relex_span_pair_embeddings, top_relex_span_mention_scores)
        output_dict['relex_scores'] = relex_scores
        output_dict['top_relex_spans'] = top_relex_spans

        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
            antecedent_labels_ = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long(

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability x to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log(
            negative_marginal_log_likelihood = -util.logsumexp(
            negative_marginal_log_likelihood *= top_span_mask.squeeze(
            negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum(

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            coref_loss = negative_marginal_log_likelihood
            output_dict['coref_loss'] = coref_loss
            loss += self._loss_coref_weight * coref_loss

        if doc_relations is not None:

            # The adjacency matrix for relation extraction is very sparse.
            # As it is not just sparse, but row/column sparse (only few
            # rows and columns are non-zero and in that case these rows/columns
            # are not sparse), we implemented our own matrix for the case.
            # Here we have indices of truth spans and mapping, using which
            # we map prediction matrix on truth matrix.
            # TODO Add teacher forcing support.

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
            relative_indices = top_relex_span_indices
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
            compressed_indices = nd_batched_padded_index_select(
                doc_spans_in_truth, relative_indices)

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
            gold_pruned_rows = nd_batched_padded_index_select(
            gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3,

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
            gold_pruned_matrices = nd_batched_padded_index_select(
                padding_value=0)  # pad with epsilon
            gold_pruned_matrices = gold_pruned_matrices.permute(
                0, 1, 3, 2).contiguous()

            # TODO log_mask relex score before passing
            relex_loss = nd_cross_entropy_with_logits(relex_scores,
            output_dict['relex_loss'] = relex_loss

            self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
            self._compute_relex_metrics(output_dict, doc_relations)

            loss += self._loss_relex_weight * relex_loss

        if doc_ner_labels is not None:
            # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
            ner_scores = self._ner_scorer(doc_span_embeddings)
            output_dict['ner_scores'] = ner_scores

            ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
            output_dict['ner_loss'] = ner_loss
            loss += self._loss_ner_weight * ner_loss

        if not isinstance(loss, int):  # If loss is not yet modified
            output_dict["loss"] = loss

        return output_dict
    def forward(self,
                inputs: torch.Tensor,
                offsets: torch.Tensor = None) -> torch.Tensor:
        """
        Embed a batch of byte-pair ids with the wrapped transformer.

        Parameters
        ----------
        inputs: ``torch.Tensor``, required
            A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
            for the current batch.
        offsets: ``torch.Tensor``, optional (default = None)
            A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
            for the current batch. When ``None``, the embeddings of all byte pairs are
            returned, truncated to the longest unpadded sequence in the batch.

        Returns
        -------
        ``torch.Tensor``
            An embedding representation of the input sequence
            having shape ``(batch_size, sequence_length, embedding_dim)``
        """
        batch_size, num_timesteps = inputs.size()

        # the transformer embedding consists of the byte pair embeddings,
        # the special embeddings and the position embeddings.
        # the position embeddings are always at least self._transformer.n_ctx,
        # but may be longer.
        # the transformer "vocab" consists of the actual vocab and the
        # positional encodings. Here we want the count of just the former.
        vocab_size = self._transformer.vocab_size - self._transformer.n_ctx

        # vocab_size, vocab_size + 1, ...
        positional_encodings = (
            get_range_vector(num_timesteps, device=get_device_of(inputs)) +
            vocab_size)

        # Combine the inputs with positional encodings
        batch_tensor = torch.stack(
            [inputs,  # (batch_size, num_timesteps)
             positional_encodings.expand(batch_size, num_timesteps)],
            dim=-1)

        # Id 0 is treated as padding.
        byte_pairs_mask = inputs != 0

        # Embeddings is num_output_layers x (batch_size, num_timesteps, embedding_dim)
        layer_activations = self._transformer(batch_tensor)

        # Output of scalar_mix is (batch_size, num_timesteps, embedding_dim)
        if self._top_layer_only:
            mix = layer_activations[-1]
        else:
            mix = self._scalar_mix(layer_activations, byte_pairs_mask)

        # These embeddings are one per byte-pair, but we want one per original _word_.
        # So we choose the embedding corresponding to the last byte pair for each word,
        # which is captured by the ``offsets`` input.
        if offsets is not None:
            range_vector = get_range_vector(
                batch_size, device=get_device_of(mix)).unsqueeze(1)
            last_byte_pair_embeddings = mix[range_vector, offsets]
        else:
            # allow to return all byte pairs by passing no offsets
            seq_len = (byte_pairs_mask > 0).long().sum(dim=1).max()
            last_byte_pair_embeddings = mix[:, :seq_len]

        return last_byte_pair_embeddings
    def forward(
            self,  # type: ignore
            label_indices: torch.LongTensor,
            token_representations: torch.FloatTensor = None,
            raw_tokens: List[List[str]] = None,
            labels: torch.LongTensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        """
        Predict a class for each selected token position.

        If ``token_representations`` is provided, ``raw_tokens`` is not required. If
        ``token_representations`` is ``None``, then ``raw_tokens`` is required.

        Parameters
        ----------
        label_indices : torch.LongTensor
            A LongTensor of shape (batch_size, max_num_adpositions) with the tokens
            to predict a label for for each element (sentence) in the batch.
        token_representations : torch.FloatTensor, optional (default = None)
            A tensor of shape (batch_size, sequence_length, representation_dim) with
            the representation of the first token. If None, we use a contextualizer
            within this model to produce the token representation.
        raw_tokens : List[List[str]], optional (default = None)
            A batch of lists with the raw token strings. Used to compute
            token_representations, if either are None.
        labels : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_label_indices)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_label_indices,
            num_classes)`` representing unnormalized log probabilities
            of the classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_label_indices,
            num_classes)`` representing a distribution of the tag classes.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimized.

        Raises
        ------
        ConfigurationError
            If ``token_representations`` is ``None`` and the model has no contextualizer.
        ValueError
            If ``token_representations`` is ``None`` and ``raw_tokens`` is ``None``.
        """
        # Convert to LongTensor
        # TODO: add PR to ArrayField to preserve array types.
        label_indices = label_indices.long()
        if token_representations is None:
            if self._contextualizer is None:
                raise ConfigurationError(
                    "token_representation not provided as input to the model, and no "
                    "contextualizer was specified. Either add a contextualizer to your "
                    "dataset reader (preferred if your contextualizer is frozen) or to "
                    "this model (if you wish to train your contextualizer).")
            if raw_tokens is None:
                raise ValueError(
                    "Input raw_tokens is ``None`` --- make sure to set "
                    "include_raw_tokens in the DatasetReader to True.")
            # NOTE(review): this check cannot fire, because ``label_indices.long()``
            # above already fails on ``None``. Kept for parity with the original.
            if label_indices is None:
                raise ValueError("Did not recieve any token indices, needed "
                                 "if the contextualizer is within the model.")
            # Convert contextualizer output into a tensor
            # Shape: (batch_size, max_seq_len, representation_dim)
            # NOTE(review): the argument was lost in the source; calling the
            # contextualizer on ``raw_tokens`` is the presumed intent -- confirm.
            token_representations, _ = pad_contextualizer_output(
                self._contextualizer(raw_tokens))

        # Move token representation to the same device as the
        # module (CPU or CUDA). TODO(nfliu): This only works if the module
        # is on one device.
        device = next(self._decoder._linear_layers[0].parameters()).device
        token_representations = token_representations.to(device)
        text_mask = get_text_mask_from_representations(token_representations)
        text_mask = text_mask.to(device)
        label_mask = self._get_label_mask_from_label_indices(label_indices)
        label_mask = label_mask.to(device)

        # Mask out the -1 padding in the label_indices, since that doesn't
        # work with indexing. Note that we can't 0 pad because 0 is actually
        # a valid label index, so we pad with -1 just for the purposes of
        # proper mask calculation and then convert to 0-padding by applying
        # the mask.
        label_indices = label_indices * label_mask

        # Encode the token representation.
        # NOTE(review): second argument reconstructed as the text mask -- confirm.
        encoded_token_representations = self._encoder(token_representations,
                                                      text_mask)

        batch_size = label_indices.size(0)
        # Index into the encoded_token_representations to get tensors corresponding
        # to the representations of the tokens to predict labels for.
        # Shape: (batch_size, num_label_indices, representation_dim)
        range_vector = get_range_vector(
            batch_size, get_device_of(label_indices)).unsqueeze(1)
        selected_token_representations = encoded_token_representations[
            range_vector, label_indices]
        selected_token_representations = selected_token_representations.contiguous()

        # Decode out a label from the token representation
        # Shape: (batch_size, num_label_indices, num_classes)
        logits = self._decoder(selected_token_representations)
        class_probabilities = F.softmax(logits, dim=-1)
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        if labels is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, labels, label_mask, average=self.loss_average)
            for name, metric in self.metrics.items():
                # When not running in error analysis mode, skip
                # metrics that start with "_"
                if not self.error_analysis and name.startswith("_"):
                    continue
                metric(logits, labels, label_mask.float())
            output_dict["loss"] = loss
        return output_dict
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Score candidate spans, prune them, and predict an antecedent for each kept span.

        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The indexed document, passed to the text embedder and to
            ``util.get_text_field_mask``.
        spans : ``torch.IntTensor``, required.
            Candidate span endpoints, indexed as ``(batch_size, num_spans, 2)``.
            Negative entries are treated as padding.
        span_labels : ``torch.IntTensor``, optional (default = None)
            Gold cluster labels per span. When given, a loss is computed and the
            metrics are updated.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Per-instance metadata. ``"original_text"`` is copied to the output, and
            ``"clusters"`` of the first instance is read in gold-mention mode.

        Returns
        -------
        An output dictionary with ``top_spans``, ``antecedent_indices`` and
        ``predicted_antecedents``, plus ``loss`` when ``span_labels`` is given and
        ``document`` when ``metadata`` is given.
        """
        # pylint: disable=arguments-differ

        # Shape: (batch_size, document_length, embedding_size)
        # NOTE(review): the embedder call was lost in the source;
        # ``self._text_field_embedder(text)`` is the presumed argument -- confirm.
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        if self._use_gold_mentions:
            # Use the tensor's own device so a non-default CUDA device is respected.
            device = text_embeddings.device

            # Only the clusters of the first instance are read here.
            s = [
                torch.as_tensor(pair, dtype=torch.long, device=device)
                for cluster in metadata[0]["clusters"] for pair in cluster
            ]
            gm = torch.stack(s, dim=0).unsqueeze(0).unsqueeze(1)

            # A candidate span is a gold mention when both endpoints equal those of
            # some gold pair. ``&`` is used rather than summing the two comparisons
            # and testing ``== 2``, which never holds for bool tensors.
            endpoint_diff = spans.unsqueeze(2) - gm
            both_match = ((endpoint_diff[:, :, :, 0] == 0)
                          & (endpoint_diff[:, :, :, 1] == 0)).long()
            span_mask, _ = both_match.max(-1)
            num_spans = span_mask.sum().item()
            span_mask = span_mask.float()
        else:
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
            num_spans = spans.size(1)
        # Clamp negative (padding) indices to 0 so they are safe to index with;
        # the padded spans stay excluded through ``span_mask``.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, emebedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, emebedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        (valid_antecedent_indices, valid_antecedent_offsets,
         valid_antecedent_log_mask) = self._generate_valid_antecedents(
             num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        # NOTE(review): trailing argument reconstructed -- confirm.
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        # NOTE(review): argument list reconstructed -- confirm against the helper.
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # Index 0 of the last dimension is the extra "1 +" slot, so subtracting 1
        # maps it to -1 and leaves the rest as indices into the antecedent list.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Negative marginal log-likelihood over all antecedents consistent
            # with the gold clustering.
            coreference_log_probs = util.last_dim_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = (
                coreference_log_probs + gold_antecedent_labels.log())
            # NOTE(review): the reduction was lost in the source; summing the
            # per-span log-sum-exp is the presumed intent -- confirm.
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Run the wrapped BERT model and optionally select one wordpiece per token.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.

        Returns
        -------
        ``torch.Tensor``
            The mixed (or top-layer) embeddings, reshaped back to the leading
            dimensions of ``input_ids`` (no offsets) or of ``offsets``.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Wordpiece id 0 is treated as padding.
        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        # NOTE(review): the ``token_type_ids`` / ``attention_mask`` keyword
        # arguments were lost in the source and are reconstructed -- confirm.
        all_encoder_layers, _ = self.bert_model(input_ids=util.combine_initial_dims(input_ids),
                                                token_type_ids=util.combine_initial_dims(token_type_ids),
                                                attention_mask=util.combine_initial_dims(input_mask))
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            return util.uncombine_initial_dims(mix, input_ids.size())

        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(offsets2d.size(0),
                                             device=util.get_device_of(mix)).unsqueeze(1)
        # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
        selected_embeddings = mix[range_vector, offsets2d]

        return util.uncombine_initial_dims(selected_embeddings, offsets.size())
    def _distance_pruning(
        self,
        top_span_embeddings: torch.FloatTensor,
        top_span_mention_scores: torch.FloatTensor,
        max_antecedents: int,
    ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
        """
        Generates antecedents for each span and prunes down to `max_antecedents`. This method
        prunes antecedents only based on distance (i.e. number of intervening spans). The closest
        antecedents are kept.

        # Parameters

        top_span_embeddings: torch.FloatTensor, required.
            The embeddings of the top spans.
            (batch_size, num_spans_to_keep, embedding_size).
        top_span_mention_scores: torch.FloatTensor, required.
            The mention scores of the top spans.
            (batch_size, num_spans_to_keep).
        max_antecedents: int, required.
            The maximum number of antecedents to keep for each span.

        # Returns

        top_partial_coreference_scores: torch.FloatTensor
            The partial antecedent scores for each span-antecedent pair. Computed by summing
            the span mentions scores of the span and the antecedent. This score is partial because
            compared to the full coreference scores, it lacks the interaction term
            w * FFNN([g_i, g_j, g_i * g_j, features]).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_mask: torch.BoolTensor
            The mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_offsets: torch.LongTensor
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e not the word distance between the spans).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_indices: torch.LongTensor
            The indices of every antecedent to consider with respect to the top k spans.
            (batch_size, num_spans_to_keep, max_antecedents)
        """
        # These antecedent matrices are independent of the batch dimension - they're just a function
        # of the span's position in top_spans.
        # The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        num_spans_to_keep = top_span_embeddings.size(1)
        device = util.get_device_of(top_span_embeddings)

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        # NOTE(review): the unpack target was lost in the source; the order
        # (indices, offsets, mask) follows the shape comment above -- confirm.
        (
            top_antecedent_indices,
            top_antecedent_offsets,
            top_antecedent_mask,
        ) = self._generate_valid_antecedents(  # noqa
            num_spans_to_keep, max_antecedents, device
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores.unsqueeze(-1), top_antecedent_indices
        ).squeeze(-1)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents) * 4
        top_partial_coreference_scores = (
            top_span_mention_scores.unsqueeze(-1) + top_antecedent_mention_scores
        )
        # Broadcast the batch-independent matrices to the batched score shape.
        top_antecedent_indices = top_antecedent_indices.unsqueeze(0).expand_as(
            top_partial_coreference_scores
        )
        top_antecedent_offsets = top_antecedent_offsets.unsqueeze(0).expand_as(
            top_partial_coreference_scores
        )
        top_antecedent_mask = top_antecedent_mask.expand_as(top_partial_coreference_scores)

        return (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        )
    def forward(
            self,
            passage_attention: torch.Tensor,
            passage_lengths: List[int],
            count_answer: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Predict a count from a passage attention vector.

        Parameters
        ----------
        passage_attention : ``torch.Tensor``
            Attention over passage tokens, unpacked as ``(batch_size, max_passage_length)``.
            Negative entries are treated as padding when building the mask.
        passage_lengths : ``List[int]``
            Not read by this method.
        count_answer : ``torch.LongTensor``, optional (default = None)
            Gold counts. When given, an L2 loss against the predicted mean is computed.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Not read by this method.

        Returns
        -------
        An output dictionary with ``loss``, ``passage_attention``, ``passage_sigmoid``,
        ``count_mean``, ``count_distritbuion``, ``count_answer`` and ``pred_count``.
        """
        device_id = allenutil.get_device_of(passage_attention)

        batch_size, max_passage_length = passage_attention.size()

        # Shape: (B, passage_length)
        passage_mask = (passage_attention >= 0).float()

        # List of (B, P) shaped tensors
        scaled_attentions = [
            passage_attention * sf for sf in self.scaling_vals
        ]
        # Shape: (B, passage_length, num_scaling_factors)
        scaled_passage_attentions = torch.stack(scaled_attentions, dim=2)

        # Shape (batch_size, 1)
        # Only used by the commented-out mean computation below; the call is kept
        # so the module is still invoked as in the original.
        passage_len_bias = self.passagelength_to_bias(
            passage_mask.sum(1, keepdim=True))

        # Zero out padded positions; the mask gains a trailing axis so it
        # broadcasts over the scaling-factor dimension.
        scaled_passage_attentions = scaled_passage_attentions * passage_mask.unsqueeze(
            2)

        # Shape: (B, passage_length, hidden_dim)
        count_hidden_repr = self.passage_attention_to_count(
            scaled_passage_attentions, passage_mask)

        # Shape: (B, passage_length, 1) -- score for each token
        passage_span_logits = self.passage_count_hidden2logits(
            count_hidden_repr)
        # Shape: (B, passage_length) -- sigmoid on token-score
        token_sigmoids = torch.sigmoid(passage_span_logits.squeeze(2))
        token_sigmoids = token_sigmoids * passage_mask

        # Shape: (B, 1) -- sum of sigmoids. This will act as the predicted mean
        # passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True) + passage_len_bias
        passage_count_mean = torch.sum(token_sigmoids, dim=1, keepdim=True)

        # Shape: (1, count_vals)
        self.countvals = allenutil.get_range_vector(
            10, device=device_id).unsqueeze(0).float()

        variance = 0.2

        # Gaussian-shaped weights over the count values 0..9, centred on the
        # predicted mean; 1e-30 keeps the normaliser away from zero.
        # Shape: (batch_size, count_vals)
        l2_by_vsquared = torch.pow(self.countvals - passage_count_mean,
                                   2) / (2 * variance * variance)
        exp_val = torch.exp(-1 * l2_by_vsquared) + 1e-30
        # Shape: (batch_size, count_vals)
        count_distribution = exp_val / (torch.sum(exp_val, 1, keepdim=True))

        # Loss computation
        output_dict = {}
        loss = 0.0
        pred_count_idx = torch.argmax(count_distribution, 1)
        if count_answer is not None:
            # L2-loss
            passage_count_mean = passage_count_mean.squeeze(1)
            # NOTE(review): the ``target`` argument was lost in the source;
            # the float-cast gold count is the presumed target -- confirm.
            L2Loss = F.mse_loss(input=passage_count_mean,
                                target=count_answer.float())
            loss = L2Loss
            predictions = passage_count_mean.detach().cpu().numpy()
            # ``np.round_`` was removed in NumPy 2.0; ``np.round`` is equivalent.
            predictions = np.round(predictions)

            gold_count = count_answer.detach().cpu().numpy()
            correct_vec = (predictions == gold_count)
            correct_perc = sum(correct_vec) / batch_size
            # print(f"{correct_perc} {predictions} {gold_count}")

            # loss = F.cross_entropy(input=count_distribution, target=count_answer)
            # List of predicted count idxs, Shape: (B,)
            # correct_vec = (pred_count_idx == count_answer).float()
            # correct_perc = torch.sum(correct_vec) / batch_size
            # self.count_acc(correct_perc.item())

        batch_loss = loss / batch_size
        output_dict["loss"] = batch_loss
        output_dict["passage_attention"] = passage_attention
        output_dict["passage_sigmoid"] = token_sigmoids
        output_dict["count_mean"] = passage_count_mean
        output_dict["count_distritbuion"] = count_distribution
        output_dict["count_answer"] = count_answer
        output_dict["pred_count"] = pred_count_idx

        return output_dict
    def _construct_loss(self,
                        head_tag_representation: torch.Tensor,
                        child_tag_representation: torch.Tensor,
                        attended_arcs: torch.Tensor,
                        head_indices: torch.Tensor,
                        head_tags: torch.Tensor,
                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : ``torch.Tensor``, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        Returns
        -------
        arc_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc loss.
        tag_nll : ``torch.Tensor``, required.
            The negative log likelihood from the arc tag loss.
        """
        float_mask = mask.float()
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # Multiplying by the mask along both sequence axes zeroes the
        # log-probabilities of padded rows and columns.
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = masked_log_softmax(attended_arcs,
                                                   mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
        normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                        mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
        # index matrix with shape (batch, sequence_length)
        timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
        child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
        # Pick, for every (batch, child) position, the log-probability of the
        # gold head and of the gold tag.
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic HEAD token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Purpose: encode question and passage, fuse them with gated
        # cross-attention and passage self-attention, then score span start/end
        # positions and (optionally) compute a loss and decode answer strings.
        #
        # NOTE(review): the tab-indented lines that follow are docstring text
        # whose opening/closing triple quotes (and apparently a few of its own
        # lines, e.g. the section headers) are missing in this copy. As written
        # they are not valid Python; restore the delimiters before running.
		question : Dict[str, torch.LongTensor]
			From a ``TextField``.
		passage : Dict[str, torch.LongTensor]
			From a ``TextField``.  The model assumes that this passage contains the answer to the
			question, and predicts the beginning and ending positions of the answer within the
		span_start : ``torch.IntTensor``, optional
			From an ``IndexField``.  This is one of the things we are trying to predict - the
			beginning position of the answer with the passage.  This is an `inclusive` token index.
			If this is given, we will compute a loss that gets included in the output dictionary.
		span_end : ``torch.IntTensor``, optional
			From an ``IndexField``.  This is one of the things we are trying to predict - the
			ending position of the answer with the passage.  This is an `inclusive` token index.
			If this is given, we will compute a loss that gets included in the output dictionary.
		metadata : ``List[Dict[str, Any]]``, optional
			If present, this should contain the question ID, original passage text, and token
			offsets into the passage for each instance in the batch.  We use this for computing
			official metrics using the official SQuAD evaluation script.  The length of this list
			should be the batch size, and each dictionary should have the keys ``id``,
			``original_passage``, and ``token_offsets``.  If you only want the best span string and
			don't care about official metrics, you can omit the ``id`` key.

		An output dictionary consisting of:
		span_start_logits : torch.FloatTensor
			A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
			probabilities of the span start position.
		span_start_probs : torch.FloatTensor
			The result of ``softmax(span_start_logits)``.
		span_end_logits : torch.FloatTensor
			A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
			probabilities of the span end position (inclusive).
		span_end_probs : torch.FloatTensor
			The result of ``softmax(span_end_logits)``.
		best_span : torch.IntTensor
			The result of a constrained inference over ``span_start_logits`` and
			``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
			and each offset is a token index.
		loss : torch.FloatTensor, optional
			A scalar loss to be optimised.
		best_span_str : List[str]
			If sufficient metadata was provided for the instances in the batch, we also return the
			string from the original passage that the model thinks is the best answer to the
        # NOTE(review): both assignments below are cut off after the opening
        # parenthesis; the argument passed to self._highway_layer is not
        # visible in this copy.
        embedded_question = self._highway_layer(
        embedded_passage = self._highway_layer(
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        # Masks are cast to float. The recurrent layers only receive a mask
        # when self._mask_lstms is truthy; otherwise they are given None.
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # The same phrase layer encodes both question and passage.
        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        # NOTE(review): encoding_dim is not used in the visible lines.
        encoding_dim = encoded_question.size(-1)

        # Passage-to-question attention: each passage position attends over
        # the question tokens.
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask, dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Question-to-passage attention reuses the same similarity matrix,
        # transposed so that each question position attends over the passage.
        # Shape: (batch_size, question_length, passage_length)
        question_passage_similarity = torch.transpose(
            passage_question_similarity, 1, 2)
        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask, dim=-1)
        # Shape: (batch_size, question_length, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)

        # NOTE(review): the gate expression is truncated; the callable wrapped
        # by torch.unsqueeze(..., -1) is missing between the next two lines.
        passage_gate = torch.unsqueeze(
                                              passage_question_vectors), -1)
        passage_fusion = self._passage_fusion_function(
            encoded_passage, passage_question_vectors)
        # Gated residual: interpolate between the fused and the original
        # passage encoding.
        gated_passage = passage_gate * passage_fusion + (
            1 - passage_gate) * encoded_passage

        # NOTE(review): truncated in the same way as passage_gate above.
        question_gate = torch.unsqueeze(
                                               question_passage_vector), -1)
        question_fusion = self._question_fusion_function(
            encoded_question, question_passage_vector)
        gated_question = question_gate * question_fusion + (
            1 - question_gate) * encoded_question

        # Passage self-attention over the gated passage representation.
        passage_passage_similarity = self._self_matrix_attention(
            gated_passage, gated_passage)
        passage_passage_attention = util.masked_softmax(
            passage_passage_similarity, passage_mask, dim=-1)
        # NOTE(review): the next two calls are truncated after their first
        # argument (second argument and closing parenthesis are missing).
        passage_passage_vector = util.weighted_sum(gated_passage,
        final_passage = self._fusion_function(gated_passage,

        modeled_passage = self._dropout(
            self._passage_modeling_layer(final_passage, passage_lstm_mask))
        # NOTE(review): modeling_dim is not used in the visible lines.
        modeling_dim = modeled_passage.size(-1)

        span_logits = self._span_predictor(modeled_passage)

        # The question is modeled and then encoded; unsqueeze(-1) adds a
        # trailing singleton dimension to the encoded result.
        modeled_question = self._question_modeling_layer(
            gated_question, question_lstm_mask)
        question_vector = self._question_encoding_layer(
            modeled_question, question_lstm_mask).unsqueeze(-1)
        # NOTE(review): both torch.bmm calls are truncated; the second operand
        # is missing.
        span_start_logits = torch.bmm(self._span_start_weight(modeled_passage),
        span_end_logits = torch.bmm(self._span_end_weight(modeled_passage),
        # Probabilities are computed from the raw logits first; afterwards the
        # masked logit positions are overwritten with -1e7 before span search.
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        # NOTE(review): the closing brace of this dict literal is missing.
        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,

        # Compute the loss for training.
        # NOTE(review): several statements in this branch are truncated
        # (get_range_vector, all three nll_loss calls, and the metric call
        # that the orphaned torch.stack(...) line belonged to).
        if span_start is not None:
            device_id = util.get_device_of(span_start)
            # Move the span weight to the GPU of span_start when there is one.
            weight = self._span_weight.cuda(
                device_id) if device_id >= 0 else self._span_weight
            arange_mask = util.get_range_vector(passage_length,
            # True for positions inside the gold span (both ends inclusive).
            span_mask = (arange_mask >= span_start) & (arange_mask <= span_end)
            span_loss = nll_loss(self._masked_log_softmax(
                span_logits, passage_mask).transpose(1, 2),
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
                                torch.stack([span_start, span_end], -1))
            # Operator precedence: only span_loss is halved, not the sum.
            output_dict["loss"] = loss + span_loss / 2

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # NOTE(review): in the visible lines nothing is appended to
        # output_dict['best_span_str'], question_tokens or passage_tokens, so
        # they stay empty lists; the appends were probably lost - confirm.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Slice the passage text from the first element of the start
                # token's offset entry to the second element of the end token's.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                # Metrics are only updated when gold answers are present.
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
    def create_cached_cnn_embeddings(self, tokens: List[str]) -> None:
        # NOTE(review): the lines down to "A list of tokens to precompute
        # character convolutions for." are docstring text whose opening and
        # closing triple quotes are missing in this copy; restore them before
        # running.
        Given a list of tokens, this method precomputes word representations
        by running just the character convolutions and highway layers of elmo,
        essentially creating uncontextual word vectors. On subsequent forward passes,
        the word ids are looked up from an embedding, rather than being computed on
        the fly via the CNN encoder.

        This function sets 3 attributes:

        _word_embedding : `torch.Tensor`
            The word embedding for each word in the tokens passed to this method.
        _bos_embedding : `torch.Tensor`
            The embedding for the BOS token.
        _eos_embedding : `torch.Tensor`
            The embedding for the EOS token.

        # Parameters

        tokens : `List[str]`, required.
            A list of tokens to precompute character convolutions for.
        # BOS and EOS are prepended so they end up in rows 0 and 1 of the
        # computed embedding matrix (read back further down).
        tokens = [
            ELMoCharacterMapper.bos_token, ELMoCharacterMapper.eos_token
        ] + tokens
        # Tokens are grouped into chunks of `timesteps`, and the chunks are
        # then grouped into batches of `batch_size`.
        timesteps = 32
        batch_size = 32
        chunked_tokens = lazy_groups_of(iter(tokens), timesteps)

        all_embeddings = []
        device = get_device_of(next(self.parameters()))
        for batch in lazy_groups_of(chunked_tokens, batch_size):
            # Shape (batch_size, timesteps, 50)
            batched_tensor = batch_to_ids(batch)
            # NOTE: This device check is for when a user calls this method having
            # already placed the model on a device. If this is called in the
            # constructor, it will probably happen on the CPU. This isn't too bad,
            # because it's only a few convolutions and will likely be very fast.
            if device >= 0:
                batched_tensor = batched_tensor.cuda(device)
            output = self._token_embedder(batched_tensor)
            token_embedding = output["token_embedding"]
            mask = output["mask"]
            token_embedding, _ = remove_sentence_boundaries(
                token_embedding, mask)
            # NOTE(review): the start of the next statement is missing. It
            # flattens token_embedding to 2-D; presumably the result was
            # appended to all_embeddings, which is otherwise never filled in
            # the visible code - confirm.
                token_embedding.view(-1, token_embedding.size(-1)))
        # NOTE(review): malformed statement - the call that combined
        # all_embeddings along dim 0 is missing from the right-hand side.
        full_embedding =, 0)

        # We might have some trailing embeddings from padding in the batch, so
        # we clip the embedding and lookup to the right size.
        full_embedding = full_embedding[:len(tokens), :]
        # Rows 0 and 1 are the BOS/EOS tokens prepended above, so the word
        # embedding matrix starts at row 2.
        embedding = full_embedding[2:len(tokens), :]
        vocab_size, embedding_dim = list(embedding.size())

        # Function-scope import (kept as in the original).
        from allennlp.modules.token_embedders import Embedding  # type: ignore

        self._bos_embedding = full_embedding[0, :]
        self._eos_embedding = full_embedding[1, :]
        # NOTE(review): the Embedding(...) call is truncated; its arguments
        # and closing parenthesis are not visible in this copy.
        self._word_embedding = Embedding(  # type: ignore
    def forward(self,
                spans_tensor: torch.FloatTensor,
                spans_mask: torch.FloatTensor,
                question_tensor: torch.FloatTensor,
                question_mask: torch.FloatTensor,
                evd_chain_labels: torch.FloatTensor,
                self_att_layer: Seq2SeqEncoder,
                sent_encoder: Seq2SeqEncoder,
                transition_mask: torch.IntTensor = None,
                start_transition_mask: torch.FloatTensor = None,
                get_all_beam: bool=False):
        # Purpose: build one vector per span, contextualise the span vectors
        # with `sent_encoder` and `self_att_layer`, run `self.evd_decoder` to
        # pick a sequence of spans, and turn the decoded indices into gate /
        # gate-probability / selection-order tensors.
        #
        # Returns the 8-tuple (all_predictions, all_logprobs, seq_logprobs,
        # gate, gate_probs, max_pooled_span_mask, att_score, orders); note that
        # att_score is always None (see below).

        #print("spans_tensor", spans_tensor.shape)
        #print("spans_mask", spans_mask.shape)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()
        # Shape: (batch_size, num_spans, embedding_dim)
        # NOTE(review): despite the "max_pooled" name, this selects index 0
        # along dim 2 of spans_tensor; no pooling happens in the visible code.
        max_pooled_span_emb = spans_tensor[:, :, 0, :]

        # self attention on spans representation
        # shape: (batch_size, num_spans, embedding_dim)
        #max_pooled_span_emb = max_pooled_span_emb.view(batch_size, num_spans, spans_tensor.size(2))
        # shape: (batch_size, num_spans)
        # A span counts as present when its mask sums to at least 1.
        max_pooled_span_mask = (torch.sum(spans_mask, dim=-1) >= 1).float()
        # shape: (batch_size, num_spans, embedding_dim)
        max_pooled_span_emb = sent_encoder(max_pooled_span_emb, max_pooled_span_mask)
        # shape: (batch_size, num_spans, embedding_dim)
        att_max_pooled_span_emb, _, att_score = self_att_layer(max_pooled_span_emb, max_pooled_span_mask)
        # Residual connection around the self-attention layer.
        max_pooled_span_emb = max_pooled_span_emb + att_max_pooled_span_emb
        # The attention scores are discarded here, so the returned att_score
        # is always None.
        att_score = None

        # extract the final hidden states as the question vector
        # Shape: (batch_size, embedding_dim)
        #question_emb = util.get_final_encoder_states(question_tensor, question_mask, True)
        # The vector at position 0 of question_tensor is used as the question
        # representation.
        question_emb = question_tensor[:, 0, :]

        # decode the most likely evidence path
        # shape (all_predictions): (batch_size, K, num_decoding_steps)
        # shape (all_logprobs): (batch_size, K, num_decoding_steps)
        # shape (seq_logprobs): (batch_size, K)
        # shape (final_hidden): (batch_size, K, decoder_output_dim)
        #print("max_pooled_span_emb", max_pooled_span_emb.shape)
        #print("max_pooled_span_mask", max_pooled_span_mask.shape)
        # NOTE(review): the print calls below write tensors to stdout on every
        # forward pass; they look like leftover debugging output.
        print("start trans mask:", start_transition_mask)
        print("trans mask:", transition_mask)
        # NOTE(review): the decoder call is truncated after its first
        # argument; the remaining arguments are not visible in this copy.
        all_predictions, all_logprobs, seq_logprobs, final_hidden = self.evd_decoder(max_pooled_span_emb,
        print("all prediction:", all_predictions)

        # The selection order of each sentence. Set to -1 if not being chosen
        # shape: (batch_size, K, num_spans)
        _, beam, num_steps = all_predictions.size()
        # Start with -1 everywhere, write the decoding step index at each
        # predicted position, then drop column 0.
        orders = spans_tensor.new_ones((batch_size, beam, 1+num_spans)) * -1
        indices = util.get_range_vector(num_steps, util.get_device_of(spans_tensor)).\
                expand(batch_size, beam, num_steps)
        orders.scatter_(2, all_predictions, indices)
        orders = orders[:, :, 1:]

        # For beamsearch, get the top one. For other helpers, just like squeeze
        if not get_all_beam:
            all_predictions = all_predictions[:, 0, :]
            all_logprobs = all_logprobs[:, 0, :]
            seq_logprobs = seq_logprobs[:, 0]
            final_hidden = final_hidden[:, 0, :]

        # build the gate. The dim is set to 1 + num_spans to account for the end embedding
        # shape: (batch_size, 1+num_spans) or (batch_size, K, 1+num_spans)
        # NOTE(review): an `else:` line appears to be missing between the two
        # assignments below (and in the three similar `if not get_all_beam:`
        # blocks further down). As written, the second assignment always
        # overwrites the first inside the `if` body.
        if not get_all_beam:
            gate = spans_tensor.new_zeros((batch_size, 1+num_spans))
            gate = spans_tensor.new_zeros((batch_size, beam, 1+num_spans))
        # Mark every predicted index with 1.
        gate.scatter_(-1, all_predictions, 1.)
        # remove the column for end embedding
        # shape: (batch_size, num_spans) or (batch_size, K, num_spans)
        gate = gate[..., 1:]
        #print("gate:", gate)
        #print("real num:", torch.sum(gate, dim=1))
        #print("seq probs:", torch.exp(seq_logprobs))

        # shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1)
        # NOTE(review): `else:` missing here as well.
        if not get_all_beam:
            gate = gate.reshape(batch_size * num_spans, 1)
            gate = gate.reshape(batch_size * beam * num_spans, 1)

        # The probability of each selected sentence being selected. If not selected, set to 0.
        # shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1)
        # NOTE(review): `else:` missing here as well.
        if not get_all_beam:
            gate_probs = spans_tensor.new_zeros((batch_size, 1+num_spans))
            gate_probs = spans_tensor.new_zeros((batch_size, beam, 1+num_spans))
        # Log-probabilities are exponentiated before being scattered.
        gate_probs.scatter_(-1, all_predictions, all_logprobs.exp())
        gate_probs = gate_probs[..., 1:]
        # NOTE(review): `else:` missing here as well.
        if not get_all_beam:
            gate_probs = gate_probs.reshape(batch_size * num_spans, 1)
            gate_probs = gate_probs.reshape(batch_size * beam * num_spans, 1)

        return all_predictions, all_logprobs, seq_logprobs, gate, gate_probs, max_pooled_span_mask, att_score, orders
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        # NOTE(review): the lines down to "(the default BertIndexer doesn't)
        # then it's assumed to be all 0s." are docstring text whose triple
        # quotes are missing in this copy; restore them before running.
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        # pylint: disable=arguments-differ
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        # Every dimension except the last (sequence) one; used further down to
        # restore the leading dimensions when the input was split.
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            # Split the flattened list by the window size, `max_pieces`
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))

            # We want all sequences to be the same length, so pad the last sequence
            last_window_size = split_input_ids[-1].size(-1)
            padding_amount = self.max_pieces - last_window_size
            # NOTE(review): the F.pad call is truncated; its remaining
            # argument(s) and closing parenthesis are missing.
            split_input_ids[-1] = F.pad(split_input_ids[-1],
                                        pad=[0, padding_amount],

            # Now combine the sequences along the batch dimension
            # NOTE(review): malformed statement - the call that joins
            # split_input_ids along dim 0 is missing from the right-hand side.
            input_ids =, dim=0)

        # Default to all-zero segment ids when none are supplied.
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Id 0 is treated as padding.
        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        # NOTE(review): the bert_model call is truncated; its arguments are
        # not visible in this copy.
        all_encoder_layers, _ = self.bert_model(
        # Stack the per-layer outputs along a new leading "layers" dimension.
        all_encoder_layers = torch.stack(all_encoder_layers)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            # NOTE(review): both statements below are truncated (the split
            # arguments and the call that re-joins the pieces along dim 2 are
            # missing).
            unpacked_embeddings = torch.split(all_encoder_layers,
            unpacked_embeddings =, dim=2)

            # Next, select indices of the sequence such that it will result in embeddings representing the original
            # sentence. To capture maximal context, the indices will be the middle part of each embedded window
            # sub-sequence (plus any leftover start and final edge windows), e.g.,
            #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
            # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
            # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
            # and final windows with indices [0, 1] and [14, 15] respectively.

            # Find the stride as half the max pieces, ignoring the special start and end tokens
            # Calculate an offset to extract the centermost embeddings of each window
            stride = (self.max_pieces - self.start_tokens -
                      self.end_tokens) // 2
            stride_offset = stride // 2 + self.start_tokens

            # Leading positions of the first window that precede the centre.
            first_window = list(range(stride_offset))

            # Positions whose index within their window lies in
            # [stride_offset, stride_offset + stride).
            # NOTE(review): the closing bracket of this list comprehension is
            # missing.
            max_context_windows = [
                i for i in range(full_seq_len) if stride_offset - 1 < i %
                self.max_pieces < stride_offset + stride

            # Trailing positions of the final window that follow its centre.
            final_window_start = full_seq_len - (
                full_seq_len % self.max_pieces) + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window


            # NOTE(review): an `else:` line appears to be missing between the
            # two assignments below; as written the second one always
            # overwrites the index-selected result.
            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
            recombined_embeddings = all_encoder_layers

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined =, dim=2)
        # NOTE(review): as written this compares the embeddings (not the ids)
        # with 0, replacing the id-based input_mask computed earlier - confirm
        # that this is intended.
        input_mask = (recombined_embeddings != 0).long()

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        # NOTE(review): the `else:` separating the "no offsets" branch from
        # the offset-selection branch is missing, and several statements in
        # the latter are truncated (get_range_vector, the indexing expression
        # and the final uncombine_initial_dims call).
        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            layers = util.uncombine_initial_dims(recombined_embeddings, dims)
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = recombined_embeddings[:, range_vector,

            layers = util.uncombine_initial_dims(selected_embeddings,

        # Output selection: scalar mix of all layers if configured, otherwise
        # the last layer when combine_layers == "last".
        # NOTE(review): a final `else:` appears to be missing before
        # `return layers`, which is unreachable as written.
        if self._scalar_mix is not None:
            return self._scalar_mix(layers, input_mask)
        elif self.combine_layers == "last":
            return layers[-1]
            return layers
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Embed every span as an attention-weighted sum of the sequence
        positions it covers.

        A single attention logit is computed per sequence position with
        ``self._global_attention``; for each span the logits of its positions
        are normalised with a masked softmax and used to average the
        corresponding rows of ``sequence_tensor``.

        Parameters
        ----------
        sequence_tensor : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, sequence_length, embedding_dim).
        span_indices : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_spans, 2) holding the
            `inclusive` start and end index of every span.
        sequence_mask : ``torch.LongTensor``, optional (default = ``None``).
            Accepted for interface compatibility; not used by this method.
        span_indices_mask : ``torch.LongTensor``, optional (default = ``None``).
            A tensor of shape (batch_size, num_spans). Spans whose mask value
            is 0 (padding spans) get an all-zero embedding.

        Returns
        -------
        A tensor of shape (batch_size, num_spans, embedding_dim).
        """
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        # Positions are enumerated backwards from the span end.
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        # Clamp negative (already masked) indices to 0 so they are valid for indexing.
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        # The logits carry a trailing singleton dimension, which is squeezed away
        # after selection so the result lines up with span_mask.
        span_attention_logits = util.batched_index_select(global_attention_logits,
                                                          span_indices,
                                                          flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

        return attended_text_embeddings
    # NOTE(review): the parameter list below has no `self`, although the body
    # uses `self._generate_valid_antecedents` and `self._coarse2fine_scorer`;
    # the `self,` line appears to have been lost in this copy.
    def _coarse_to_fine_pruning(
        top_span_embeddings: torch.FloatTensor,
        top_span_mention_scores: torch.FloatTensor,
        top_span_mask: torch.BoolTensor,
        max_antecedents: int,
    ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
        # NOTE(review): the lines down to the description of
        # `top_antecedent_indices` are docstring text whose opening/closing
        # triple quotes are missing in this copy; restore them before running.
        Generates antecedents for each span and prunes down to `max_antecedents`. This method
        prunes antecedents using a fast bilinar interaction score between a span and a candidate
        antecedent, and the highest-scoring antecedents are kept.

        # Parameters

        top_span_embeddings: torch.FloatTensor, required.
            The embeddings of the top spans.
            (batch_size, num_spans_to_keep, embedding_size).
        top_span_mention_scores: torch.FloatTensor, required.
            The mention scores of the top spans.
            (batch_size, num_spans_to_keep).
        top_span_mask: torch.BoolTensor, required.
            The mask for the top spans.
            (batch_size, num_spans_to_keep).
        max_antecedents: int, required.
            The maximum number of antecedents to keep for each span.

        # Returns

        top_partial_coreference_scores: torch.FloatTensor
            The partial antecedent scores for each span-antecedent pair. Computed by summing
            the span mentions scores of the span and the antecedent as well as a bilinear
            interaction term. This score is partial because compared to the full coreference scores,
            it lacks the interaction term
            w * FFNN([g_i, g_j, g_i * g_j, features]).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_mask: torch.BoolTensor
            The mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_offsets: torch.LongTensor
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e not the word distance between the spans).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_indices: torch.LongTensor
            The indices of every antecedent to consider with respect to the top k spans.
            (batch_size, num_spans_to_keep, max_antecedents)
        batch_size, num_spans_to_keep = top_span_embeddings.size()[:2]
        device = util.get_device_of(top_span_embeddings)

        # Every kept span is a candidate antecedent of every other kept span,
        # hence num_spans_to_keep is passed for both counts; only the mask
        # (third return value) is used.
        # Shape: (1, num_spans_to_keep, num_spans_to_keep)
        # NOTE(review): the closing parenthesis of this call is missing.
        _, _, valid_antecedent_mask = self._generate_valid_antecedents(
            num_spans_to_keep, num_spans_to_keep, device

        # Unsqueezing on dims 1 and 2 lets the two mention scores broadcast
        # into a pairwise (span, antecedent) grid when added below.
        mention_one_score = top_span_mention_scores.unsqueeze(1)
        mention_two_score = top_span_mention_scores.unsqueeze(2)
        # Bilinear term: the span embeddings multiplied by the transposed
        # output of self._coarse2fine_scorer applied to the same embeddings.
        bilinear_weights = self._coarse2fine_scorer(top_span_embeddings).transpose(1, 2)
        bilinear_score = torch.matmul(top_span_embeddings, bilinear_weights)
        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
        partial_antecedent_scores = mention_one_score + mention_two_score + bilinear_score

        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
        span_pair_mask = top_span_mask.unsqueeze(-1) & valid_antecedent_mask

        # Shape:
        # (batch_size, num_spans_to_keep, max_antecedents) * 3
        # NOTE(review): the opening of the tuple-unpacking target (the names
        # receiving the three results of masked_topk) is missing before the
        # `)` below.
        ) = util.masked_topk(partial_antecedent_scores, span_pair_mask, max_antecedents)

        top_span_range = util.get_range_vector(num_spans_to_keep, device)
        # Shape: (num_spans_to_keep, num_spans_to_keep); broadcast op
        valid_antecedent_offsets = top_span_range.unsqueeze(-1) - top_span_range.unsqueeze(0)

        # TODO: we need to make `batched_index_select` more general to make this less awkward.
        # NOTE(review): the first line of the argument (the tensor that
        # `.expand(...)` is called on) is missing; presumably it is
        # valid_antecedent_offsets, which is otherwise unused - confirm.
        top_antecedent_offsets = util.batched_index_select(
            .expand(batch_size, num_spans_to_keep, num_spans_to_keep)
            .reshape(batch_size * num_spans_to_keep, num_spans_to_keep, 1),
            top_antecedent_indices.view(-1, max_antecedents),
        ).reshape(batch_size, num_spans_to_keep, max_antecedents)

        # NOTE(review): the return statement is truncated; the returned tuple
        # elements are not visible in this copy.
        return (
    def inference_coref(self, batch, embedded_text_input_relation, mask):
        """
        Run the coreference sub-model (``self.model._tagger_coref``) on one batch
        and return the antecedent predicted for every retained span.

        Parameters
        ----------
        batch : ``dict``, required
            Must contain the key ``"spans"``. The tensor is indexed as
            ``spans[:, :, 0]`` and ``spans.size(1)`` is used as the number of
            candidate spans. Entries that are negative are treated as padding.
            (assumes shape ``(batch_size, num_spans, 2)`` -- TODO confirm)
        embedded_text_input_relation : required
            Embedded text fed to the sub-model's context layer and to its
            attentive span extractor.
        mask : required
            Mask over the text; ``mask.size(1)`` is used as the document length
            and the mask's device is used for the antecedent index tensors.

        Returns
        -------
        output_dict : ``dict``
            ``"top_spans"``: the spans kept by the mention pruner.
            ``"antecedent_indices"``: the candidate antecedent indices considered
            for each kept span.
            ``"predicted_antecedents"``: arg-max over the coreference scores,
            shifted down by one.
            ``"coreference_scores"``: the scores the arg-max was taken over.
        """
        submodel = self.model._tagger_coref

        ### Fast inference of coreference ###
        spans = batch["spans"]

        document_length = mask.size(1)
        num_spans = spans.size(1)

        # Padding spans carry negative indices. Build the mask before clamping.
        # No ``squeeze(-1)`` here: after indexing away the last axis it could only
        # ever collapse the span axis (when there is exactly one candidate span).
        span_mask = (spans[:, :, 0] >= 0).float()
        # Clamp the padding indices to zero so they are safe to use for indexing.
        spans = F.relu(spans.float()).long()

        encoded_text_coref = submodel._context_layer(embedded_text_input_relation, mask)
        endpoint_span_embeddings = submodel._endpoint_span_extractor(encoded_text_coref, spans)
        attended_span_embeddings = submodel._attentive_span_extractor(
            embedded_text_input_relation, spans
        )

        # Concatenate both span representations along the feature axis.
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1
        )

        # The pruning budget scales with the document length.
        num_spans_to_keep = int(math.floor(submodel._spans_per_word * document_length))

        (
            top_span_embeddings,
            top_span_mask,
            top_span_indices,
            top_span_mention_scores,
        ) = submodel._mention_pruner(span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)

        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans
        )
        # NOTE(review): the last argument of this call was cut off in the file;
        # ``flat_top_span_indices`` is computed just above and used nowhere else.
        top_spans = util.batched_index_select(
            spans, top_span_indices, flat_top_span_indices
        )

        max_antecedents = min(submodel._max_antecedents, num_spans_to_keep)

        (
            valid_antecedent_indices,
            valid_antecedent_offsets,
            valid_antecedent_log_mask,
        ) = submodel._generate_valid_antecedents(
            num_spans_to_keep, max_antecedents, util.get_device_of(mask)
        )

        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices
        )
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices
        ).squeeze(-1)

        # NOTE(review): the trailing arguments of the next two calls were cut off
        # in the file. They are restored from the locals computed above that are
        # not consumed anywhere else -- confirm against the sub-model signatures.
        span_pair_embeddings = submodel._compute_span_pair_embeddings(
            top_span_embeddings,
            candidate_antecedent_embeddings,
            valid_antecedent_offsets,
        )
        coreference_scores = submodel._compute_coreference_scores(
            span_pair_embeddings,
            top_span_mention_scores,
            candidate_antecedent_mention_scores,
            valid_antecedent_log_mask,
        )

        _, predicted_antecedents = coreference_scores.max(2)
        # Shift the arg-max down by one, so a winning column 0 becomes -1.
        # (presumably column 0 is a "no antecedent" slot -- confirm in
        # ``_compute_coreference_scores``)
        predicted_antecedents -= 1

        # NOTE(review): the end of this literal was cut off in the file. The raw
        # scores are included as an extra key -- confirm against the callers.
        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
            "coreference_scores": coreference_scores,
        }

        return output_dict