Example #1
class TransformerHead(LoggedModule):
    def __init__(self, config, v_dim, l_dim, loc_dim, backbone):
        super(TransformerHead, self).__init__()
        self.config = config.MODEL.MMSS_HEAD.TRANSFORMER
        self.v_dim = v_dim
        self.l_dim = l_dim
        self.loc_dim = loc_dim
        self.backbone = backbone

        self.mvm_loss = self.config.MVM_LOSS
        self.mmm_loss = self.config.MMM_LOSS
        self.num_negative = self.config.MVM_LOSS_NUM_NEGATIVE

        self.bert_config = BertConfig(**self.config.BERT_CONFIG)
        self.v2l_projection = nn.Linear(self.v_dim, self.l_dim)
        self.visual_emb = VisualEmbedding(self.bert_config, self.l_dim, self.loc_dim)
        self.encoder = BertEncoder(self.bert_config)
        self.pooler = BertPooler(self.bert_config)
        self.heads = MMPreTrainingHeads(self.bert_config, self.v_dim)

        self.encoder.apply(self._init_weights)
        self.pooler.apply(self._init_weights)
        self.heads.apply(self._init_weights)

        self._tie_weights()

        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        if self.mvm_loss == 'reconstruction_error':
            self.vis_criterion = nn.MSELoss(reduction="none")
        elif self.mvm_loss == 'contrastive_cross_entropy':
            self.vis_criterion = nn.CrossEntropyLoss()
        elif self.mvm_loss == '':
            self.vis_criterion = None
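            # MVM disabled: freeze the visual prediction head so it receives no gradient updates.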
            for p in self.heads.imagePredictions.parameters():
                p.requires_grad = False
        else:
            raise NotImplementedError

        if self.mmm_loss == '':
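            # matching loss disabled: freeze the pooler and the matching head as well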
            for p in self.pooler.parameters():
                p.requires_grad = False
            for p in self.heads.bi_seq_relationship.parameters():
                p.requires_grad = False

    def _tie_weights(self):
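        # Tie the MLM decoder weights to the backbone's embedding table (shapes are checked below).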
        assert(self.heads.predictions.decoder.weight.shape[0] == 
               self.backbone.embeddings.shape[0])
        assert(self.heads.predictions.decoder.weight.shape[1] == 
               self.backbone.embeddings.shape[1])
        self.heads.predictions.decoder.weight = self.backbone.embeddings

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_image, input_caption):
        '''
        Mapping between my terminology and the vilbert codebase:
        input_ids => Not needed. Instead I have caption_emb which is the output of BERT
        image_feat => region_features
        image_loc => region_loc
        token_type_ids => Not needed. All zero anyway
        attention_mask => caption_mask,
        image_attention_mask => region_mask,
        masked_lm_labels => combination of target_caption_ids, mlm_mask,
        image_label => mvm_mask,
        image_target => target_region_features,
        next_sentence_label => Not needed for now (image_caption_match_label).
        '''

        caption_emb = input_caption['encoded_tokens']
        caption_mask = input_caption['attention_mask']
        mlm_mask = input_caption['mlm_mask']
        target_caption_ids = input_caption['target_ids']

        region_features = input_image['region_features']
        region_mask = input_image['region_mask']
        region_loc = input_image['region_loc']
        mvm_mask = input_image['mvm_mask']
        target_region_features = input_image['target_region_features']

        target_caption_ids = torch.where(
            mlm_mask > 0,
            target_caption_ids,
            torch.ones_like(target_caption_ids) * (-1)
        )
        caption_mask = caption_mask.to(torch.float32)
        region_mask = region_mask.to(torch.float32)
        mlm_mask = mlm_mask.to(torch.float32)
        mvm_mask = mvm_mask.to(torch.float32)
        num_words = caption_mask.sum(dim=1)
        _, max_num_words = caption_mask.shape
        batch_size, max_num_regions, _ = region_features.shape

        image_emb = self.v2l_projection(region_features)
        image_emb = self.visual_emb(image_emb, region_loc)
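        # For the batch-wise matching loss below, captions and images are tiled so that every
        # caption is paired with every image in the batch (batch_size**2 pairs).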

        if self.mmm_loss == 'cross_entropy':
            image_emb = image_emb[None, :, :, :].repeat(batch_size, 1, 1, 1).reshape(
                batch_size**2, max_num_regions, self.l_dim)
            caption_emb = caption_emb[:, None, :, :].repeat(1, batch_size, 1, 1).reshape(
                batch_size**2, max_num_words, self.l_dim)
            region_mask = region_mask[None, :, :].repeat(batch_size, 1, 1).reshape(
                batch_size**2, max_num_regions)
            caption_mask = caption_mask[:, None, :].repeat(1, batch_size, 1).reshape(
                batch_size**2, max_num_words)

        embedded_tokens = torch.cat([caption_emb, image_emb], dim=1)
        attention_mask = torch.cat([caption_mask, region_mask], dim=1)
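        # The concatenated [caption tokens; image regions] sequence is encoded jointly below;
        # the mask is broadcast to [batch, 1, 1, seq_len] across attention heads and query positions.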

        sequence_output, = self.encoder(
            embedded_tokens,
            attention_mask[:, None, None, :],
            head_mask=[None] * self.bert_config.num_hidden_layers,
            output_attentions=False,
            output_hidden_states=False,
        )
        pooled_output = self.pooler(sequence_output)
        sequence_output_t, sequence_output_v = torch.split(
            sequence_output,
            [max_num_words, max_num_regions],
            dim=1
        )
        prediction_scores_t, prediction_scores_v, seq_relationship_score = self.heads(
            sequence_output_t, sequence_output_v, pooled_output
        )
        # prediction_scores_v = prediction_scores_v[:, 1:]
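        # In the batch-wise matching setup, the diagonal recovers predictions for the matched
        # (i-th caption, i-th image) pairs, so the MLM/MVM losses below use only true pairs.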

        if self.mmm_loss == 'cross_entropy':
            prediction_scores_t = torch.diagonal(prediction_scores_t.reshape(
                batch_size, batch_size, max_num_words, self.bert_config.vocab_size),
                dim1=0, dim2=1).permute(2, 0, 1)
            prediction_scores_v = torch.diagonal(prediction_scores_v.reshape(
                batch_size, batch_size, max_num_regions, self.v_dim),
                dim1=0, dim2=1).permute(2, 0, 1)

        masked_lm_loss = self.loss_fct(
            prediction_scores_t.reshape(-1, self.bert_config.vocab_size),
            target_caption_ids.reshape(-1),
        )

        if self.mmm_loss == 'binary':
            raise NotImplementedError
            next_sentence_loss = self.loss_fct(
                seq_relationship_score.view(-1, 2), image_caption_match_label.view(-1)
            )
        elif self.mmm_loss == 'cross_entropy':
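            # pw_cost[i, j] is the matching cost between caption i and image j. A log-softmax over
            # dim=0 scores candidate captions for each image and one over dim=1 scores candidate
            # images for each caption; the ground-truth pairs lie on the diagonal.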
            global_dist = seq_relationship_score[:, 0]
            pw_cost = global_dist.reshape(batch_size, batch_size)
            pw_logits_c_cap = torch.log_softmax(- pw_cost, dim=0)
            pw_logits_c_img = torch.log_softmax(- pw_cost, dim=1)
            next_sentence_loss_c_cap = torch.diag(- pw_logits_c_cap).mean()
            next_sentence_loss_c_img = torch.diag(- pw_logits_c_img).mean()
            next_sentence_loss = next_sentence_loss_c_cap + next_sentence_loss_c_img
        elif self.mmm_loss == '':
            next_sentence_loss = torch.tensor(0.0).cuda()
        else:
            raise NotImplementedError

        if self.mvm_loss == 'reconstruction_error':
            raise NotImplementedError
            img_loss = self.vis_criterion(prediction_scores_v, target_region_features)
            masked_img_loss = torch.sum(
                img_loss * (mvm_mask == 1).unsqueeze(2).float()
            ) / max(
                torch.sum((mvm_mask == 1).unsqueeze(2).expand_as(img_loss)), 1
            )
        elif self.mvm_loss == 'contrastive_cross_entropy':
            raise NotImplementedError
            # generate negative sampled index.
            num_negative = self.num_negative
            num_across_batch = int(self.num_negative * 0.7)
            num_inside_batch = int(self.num_negative * 0.3)

            # random negative across batches.
            row_across_index = target_caption_ids.new(
                batch_size, max_num_regions, num_across_batch
            ).random_(0, batch_size - 1)
            col_across_index = target_caption_ids.new(
                batch_size, max_num_regions, num_across_batch
            ).random_(0, max_num_regions)

            for i in range(batch_size - 1):
                row_across_index[i][row_across_index[i] == i] = batch_size - 1
            final_across_index = row_across_index * max_num_regions + col_across_index

            # random negative inside batches.
            row_inside_index = target_caption_ids.new(
                batch_size, max_num_regions, num_inside_batch
            ).zero_()
            col_inside_index = target_caption_ids.new(
                batch_size, max_num_regions, num_inside_batch
            ).random_(0, max_num_regions - 1)

            for i in range(batch_size):
                row_inside_index[i] = i
            for i in range(max_num_regions - 1):
                col_inside_index[:, i, :][col_inside_index[:, i, :] == i] = (
                    max_num_regions - 1
                )
            final_inside_index = row_inside_index * max_num_regions + col_inside_index

            final_index = torch.cat((final_across_index, final_inside_index), dim=2)

            # Let's first sample where we need to compute.
            predict_v = prediction_scores_v[mvm_mask == 1]
            neg_index_v = final_index[mvm_mask == 1]

            flat_image_target = target_region_features.view(batch_size * max_num_regions, -1)
            # we also need to append the target feature at the beginning.
            negative_v = flat_image_target[neg_index_v]
            positive_v = target_region_features[mvm_mask == 1]
            sample_v = torch.cat((positive_v.unsqueeze(1), negative_v), dim=1)

            # calculate the loss.
            score = torch.bmm(sample_v, predict_v.unsqueeze(2)).squeeze(2)
            masked_img_loss = self.vis_criterion(
                score, target_caption_ids.new(score.size(0)).zero_()
            )
        elif self.mvm_loss == '':
            masked_img_loss = torch.tensor(0.0).cuda()
        else:
            raise NotImplementedError
        # masked_img_loss = torch.sum(img_loss) / (img_loss.shape[0] * img_loss.shape[1])

        losses = {
            'Masked Language Modeling Loss': masked_lm_loss,
            'Masked Visual Modeling Loss': masked_img_loss,
            'Image Caption Matching Loss': next_sentence_loss,
        }
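        # MLM accuracy is measured only over masked positions (target ids >= 0); torch.where
        # guards against division by zero when no token is masked.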
        acc_num = (prediction_scores_t.argmax(dim=-1) == target_caption_ids).to(torch.float32).sum()
        acc_denom = (target_caption_ids >= 0).to(torch.float32).sum()
        acc = torch.where(acc_denom > 0, acc_num / acc_denom, acc_denom)
        other_info = {
            'Masked Language Modeling Accuracy': acc,
        }
        if self.mmm_loss == 'cross_entropy':
            other_info['Batch Accuracy (Choose Caption)'] = torch.mean(
                (pw_cost.argmin(dim=0) == torch.arange(batch_size).to('cuda')).to(torch.float32))
            other_info['Batch Accuracy (Choose Image)'] = torch.mean(
                (pw_cost.argmin(dim=1) == torch.arange(batch_size).to('cuda')).to(torch.float32))

        self.log_dict(losses)
        self.log_dict(other_info)

        return other_info, losses
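
For reference, the batch-wise caption-image matching loss above reduces to the following minimal standalone sketch (plain PyTorch; the function and tensor names are illustrative, not part of the original code):

import torch

def batch_matching_loss(pw_cost: torch.Tensor) -> torch.Tensor:
    # pw_cost[i, j]: matching cost between caption i and image j (lower = better match).
    log_p_caption = torch.log_softmax(-pw_cost, dim=0)  # choose a caption for each image
    log_p_image = torch.log_softmax(-pw_cost, dim=1)    # choose an image for each caption
    # The ground-truth pairs sit on the diagonal; penalize their negative log-likelihood.
    return -torch.diag(log_p_caption).mean() - torch.diag(log_p_image).mean()

# e.g. for a batch of 4 caption-image pairs:
loss = batch_matching_loss(torch.randn(4, 4))
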
Example #2
class BiaffineDependencyS2SQueryParser(torch.nn.Module):
    """
    This dependency parser follows the model of
    [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)]
    (https://arxiv.org/abs/1611.01734) .
    But we use token-to-token MRC to extract parents and labels.

    Args:
        config: S2SConfig that defines model dim and structure
        bert_dir: pretrained bert directory
    """

    # def __init__(self, config: Union[BertMrcS2SDependencyConfig, RobertaMrcS2SDependencyConfig]):
    def __init__(self, config: S2SConfig, bert_dir: str = ""):
        # super().__init__(config)
        super().__init__()

        self.config = config

        num_dep_labels = len(config.dep_tags)
        num_pos_labels = len(config.pos_tags)
        hidden_size = config.additional_layer_dim

        if config.pos_dim > 0:
            self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim)
            nn.init.xavier_uniform_(self.pos_embedding.weight)
            if config.additional_layer_type != "lstm" and config.pos_dim + config.bert_config.hidden_size != hidden_size:
                self.fuse_layer = nn.Linear(
                    config.pos_dim + config.bert_config.hidden_size,
                    hidden_size)
                nn.init.xavier_uniform_(self.fuse_layer.weight)
                self.fuse_layer.bias.data.zero_()
            else:
                self.fuse_layer = None
        else:
            self.pos_embedding = None

        # if isinstance(config, BertMrcS2SDependencyConfig):
        #     self.bert = BertModel(config)
        #     self.arch = "bert"
        # else:
        #     self.roberta = RobertaModel(config)
        #     self.arch = "roberta"
        self.bert = AutoModel.from_pretrained(
            pretrained_model_name_or_path=bert_dir, config=config.bert_config)

        if config.additional_layer > 0:
            if config.additional_layer_type == "transformer":
                new_config = deepcopy(config.bert_config)
                new_config.hidden_size = hidden_size
                new_config.num_hidden_layers = config.additional_layer
                new_config.hidden_dropout_prob = new_config.attention_probs_dropout_prob = config.mrc_dropout
                # new_config.attention_probs_dropout_prob = config.biaf_dropout  # todo add to hparams and tune
                self.additional_encoder = BertEncoder(new_config)
                self.additional_encoder.apply(self._init_bert_weights)
            else:
                assert hidden_size % 2 == 0, "Bi-LSTM needs an even hidden_size"
                self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder(
                    input_size=config.pos_dim + config.bert_config.hidden_size,
                    hidden_size=hidden_size // 2,
                    num_layers=config.additional_layer,
                    recurrent_dropout_probability=config.mrc_dropout,
                    use_highway=True)

        else:
            self.additional_encoder = None

        # todo use MLP
        self.parent_feedforward = nn.Linear(hidden_size, 1)
        self.parent_start_feedforward = nn.Linear(hidden_size, 1)
        self.parent_end_feedforward = nn.Linear(hidden_size, 1)
        self.parent_tag_feedforward = nn.Linear(hidden_size, num_dep_labels)

        if config.predict_child:
            self.child_feedforward = nn.Linear(hidden_size, 1)
            self.child_start_feedforward = nn.Linear(hidden_size, 1)
            self.child_end_feedforward = nn.Linear(hidden_size, 1)

        # self._dropout = nn.Dropout(config.mrc_dropout)
        self._dropout = InputVariationalDropout(config.mrc_dropout)
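        # InputVariationalDropout (AllenNLP) drops the same feature channels at every timestep,
        # unlike ordinary element-wise dropout.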

        # init linear children
        for layer in self.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def _init_bert_weights(self, module):
        """ Initialize the weights. copy from transformers.BertPreTrainedModel"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.bert_config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @overrides
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        parent_mask: torch.BoolTensor,
        parent_start_mask: torch.BoolTensor,
        parent_end_mask: torch.BoolTensor,
        child_mask: torch.BoolTensor = None,
        parent_idxs: torch.LongTensor = None,
        parent_tags: torch.LongTensor = None,
        parent_starts: torch.BoolTensor = None,
        parent_ends: torch.BoolTensor = None,
        child_idxs: torch.BoolTensor = None,
        child_starts: torch.BoolTensor = None,
        child_ends: torch.BoolTensor = None,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            parent_mask: [batch_size, num_words]
            parent_start_mask: [batch_size, num_words]
            parent_end_mask: [batch_size, num_words]
            child_mask: [batch_size, num_words]
            parent_idxs: [batch_size]
            parent_tags: [batch_size]
            parent_starts: [batch_size]
            parent_ends: [batch_size]
            child_idxs: [batch_size, num_words]
            child_starts: [batch_size, num_words]
            child_ends: [batch_size, num_words]
        Returns:
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words, num_tags]
            parent_start_probs: [batch_size, num_words]
            parent_end_probs: [batch_size, num_words]
            child_probs: [batch_size, num_words]
            child_start_probs: [batch_size, num_words]
            child_end_probs: [batch_size, num_words]
            arc_loss (if parent_idxs is not None)
            tag_loss (if parent_idxs and parent_tags are not None)
            start_loss (if parent_starts is not None)
            end_loss (if parent_ends is not None)
            child_loss (if child_idxs is not None)
            child_start_loss (if child_starts is not None)
            child_end_loss (if child_ends is not None)
        """

        cls_embedding, embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
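        # cls_embedding (the [CLS] vector) is returned by get_word_embedding but not used
        # further in this forward pass.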
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                # bert = self.bert if self.arch == "bert" else self.roberta
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
        parent_start_scores = self.parent_start_feedforward(
            encoded_text).squeeze(-1)
        parent_end_scores = self.parent_end_feedforward(encoded_text).squeeze(
            -1)

        # mask out impossible positions
        minus_inf = -1e8
        parent_mask = torch.logical_and(parent_mask, word_mask)
        parent_scores = parent_scores + (~parent_mask).float() * minus_inf
        parent_start_mask = torch.logical_and(parent_start_mask, word_mask)
        parent_start_scores = parent_start_scores + (
            ~parent_start_mask).float() * minus_inf
        parent_end_mask = torch.logical_and(parent_end_mask, word_mask)
        parent_end_scores = parent_end_scores + (
            ~parent_end_mask).float() * minus_inf

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_start_probs = F.softmax(parent_start_scores, dim=-1)
        parent_end_probs = F.softmax(parent_end_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

        output = (parent_probs, parent_tag_probs, parent_start_probs,
                  parent_end_probs)

        if self.config.predict_child:
            child_scores = self.child_feedforward(encoded_text).squeeze(-1)
            child_start_scores = self.child_start_feedforward(
                encoded_text).squeeze(-1)
            child_end_scores = self.child_end_feedforward(
                encoded_text).squeeze(-1)
            # todo add child mask - child should be inside the origin span
            if child_mask is None:
                child_mask = torch.ones_like(word_mask)
            else:
                child_mask = torch.logical_and(child_mask, word_mask)
            child_scores = child_scores + (~child_mask).float() * minus_inf
            child_start_scores = child_start_scores + (
                ~child_mask).float() * minus_inf
            child_end_scores = child_end_scores + (
                ~child_mask).float() * minus_inf
            child_probs = torch.sigmoid(child_scores)
            child_start_probs = torch.sigmoid(child_start_scores)
            child_end_probs = torch.sigmoid(child_end_scores)
            output = output + (child_probs, child_start_probs, child_end_probs)

        # add losses
        batch_range_vector = get_range_vector(
            batch_size, get_device_of(encoded_text))  # [bsz]
        if parent_idxs is not None:
            # [bsz, seq_len]
            parent_logits = F.log_softmax(parent_scores, dim=-1)
            parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
            parent_arc_nll = parent_arc_nll.mean()
            output = output + (parent_arc_nll, )

            if parent_tags is not None:
                parent_tag_nll = F.cross_entropy(
                    parent_tag_scores[batch_range_vector, parent_idxs],
                    parent_tags)
                output = output + (parent_tag_nll, )

        if parent_starts is not None:
            # [bsz, seq_len]
            parent_start_logits = F.log_softmax(parent_start_scores, dim=-1)
            parent_start_nll = -parent_start_logits[batch_range_vector,
                                                    parent_starts].mean()
            output = output + (parent_start_nll, )

        if parent_ends is not None:
            # [bsz, seq_len]
            parent_end_logits = F.log_softmax(parent_end_scores, dim=-1)
            parent_end_nll = -parent_end_logits[batch_range_vector,
                                                parent_ends].mean()
            output = output + (parent_end_nll, )

        if self.config.predict_child:
            if child_idxs is not None:
                child_loss = F.binary_cross_entropy_with_logits(
                    child_scores, child_idxs.float(), reduction="none")
                child_loss = (child_loss *
                              child_mask).sum() / (child_mask.sum() + 1e-8)
                output = output + (child_loss, )
            if child_starts is not None:
                child_start_loss = F.binary_cross_entropy_with_logits(
                    child_start_scores, child_starts.float(), reduction="none")
                child_start_loss = (child_start_loss * child_mask).sum() / (
                    child_mask.sum() + 1e-8)
                output = output + (child_start_loss, )
            if child_ends is not None:
                child_end_loss = F.binary_cross_entropy_with_logits(
                    child_end_scores, child_ends.float(), reduction="none")
                child_end_loss = (child_end_loss *
                                  child_mask).sum() / (child_mask.sum() + 1e-8)
                output = output + (child_end_loss, )

        return output

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore
        """get [CLS] embedding and word-level embedding"""
        # Shape: [batch_size, num_wordpieces, embedding_size].
        # embed_model = self.bert if self.arch == "bert" else self.roberta
        # embeddings = embed_model(token_ids, token_type_ids=type_ids, attention_mask=wordpiece_mask)[0]
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]
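        # offsets[b, i] holds the (start, end) wordpiece span of word i; the wordpiece vectors of
        # each word are mean-pooled below into a single word-level vector.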

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return embeddings[:, 0, :], orig_embeddings
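
The wordpiece-to-word pooling in get_word_embedding is reused verbatim by the parsers below. A minimal standalone sketch of the idea (hypothetical shapes and values; assumes AllenNLP is available):

import torch
from allennlp.nn import util as allennlp_util

# 1 sentence, 5 wordpieces of dim 4; word 0 spans pieces 0-1, word 1 spans pieces 2-4 (inclusive).
embeddings = torch.randn(1, 5, 4)
offsets = torch.tensor([[[0, 1], [2, 4]]])

span_embeddings, span_mask = allennlp_util.batched_span_select(embeddings, offsets)
span_mask = span_mask.unsqueeze(-1)
word_embeddings = (span_embeddings * span_mask).sum(2) / span_mask.sum(2).clamp_min(1)
# word_embeddings: [1, 2, 4] -- one mean-pooled vector per original word
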
class BiaffineDependencyT2TParser(nn.Module):
    """
    This dependency parser follows the model of
    [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)]
    (https://arxiv.org/abs/1611.01734) .
    But we use token-to-token MRC to extract parents and labels.
    """
    def __init__(self, bert_dir, config):
        super().__init__()

        self.config = config

        num_dep_labels = len(config.dep_tags)
        num_pos_labels = len(config.pos_tags)
        hidden_size = config.additional_layer_dim

        if config.pos_dim > 0:
            self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim)
            nn.init.xavier_uniform_(self.pos_embedding.weight)
            if config.additional_layer_type != "lstm" and config.pos_dim + config.hidden_size != hidden_size:
                self.fuse_layer = nn.Linear(
                    config.pos_dim + config.hidden_size, hidden_size)
                nn.init.xavier_uniform_(self.fuse_layer.weight)
                self.fuse_layer.bias.data.zero_()
            else:
                self.fuse_layer = None
        else:
            self.pos_embedding = None

        if isinstance(config, BertMrcT2TDependencyConfig):
            self.bert = BertModel.from_pretrained(bert_dir, config=self.config)
        elif isinstance(config, RobertaMrcT2TDependencyConfig):
            self.bert = RobertaModel.from_pretrained(bert_dir,
                                                     config=self.config)

        if config.additional_layer > 0:
            if config.additional_layer_type == "transformer":
                new_config = deepcopy(config)
                new_config.hidden_size = hidden_size
                new_config.num_hidden_layers = config.additional_layer
                new_config.hidden_dropout_prob = new_config.attention_probs_dropout_prob = config.mrc_dropout
                # new_config.attention_probs_dropout_prob = config.biaf_dropout  # todo add to hparams and tune
                self.additional_encoder = BertEncoder(new_config)
                self.additional_encoder.apply(self._init_bert_weights)
            else:
                assert hidden_size % 2 == 0, "Bi-LSTM needs an even hidden_size"
                self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder(
                    input_size=config.pos_dim + config.hidden_size,
                    hidden_size=hidden_size // 2,
                    num_layers=config.additional_layer,
                    recurrent_dropout_probability=config.mrc_dropout,
                    use_highway=True)

        else:
            self.additional_encoder = None

        self.parent_feedforward = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            InputVariationalDropout(config.mrc_dropout),
            nn.Linear(hidden_size // 2, 1),
        )

        self.parent_tag_feedforward = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            InputVariationalDropout(config.mrc_dropout),
            nn.Linear(hidden_size // 2, num_dep_labels),
        )

        self.child_feedforward = deepcopy(self.parent_feedforward)
        self.child_tag_feedforward = deepcopy(self.parent_tag_feedforward)
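        # deepcopy gives the child heads the same architecture and initial weights as the parent
        # heads, but independent parameters from then on.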

        # self.mrc_dropout = nn.Dropout(config.mrc_dropout)
        self._dropout = InputVariationalDropout(config.mrc_dropout)
        # todo: re-init the feedforward layers?

    def _init_bert_weights(self, module):
        """ Initialize the weights. copy from transformers.BertPreTrainedModel"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @overrides
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        span_idx: torch.LongTensor,
        span_tag: torch.LongTensor,
        child_arcs: torch.LongTensor,
        child_tags: torch.LongTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        mrc_mask: torch.BoolTensor,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            span_idx: [batch_size, 2]
            span_tag: [batch_size, 1]
            child_arcs: [batch_size, num_words]
            child_tags: [batch_size, num_words]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            mrc_mask: [batch_size, num_words]
        Returns:
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words]
            arc_nll: [1]
            tag_nll: [1]
        """

        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)

        # [bsz, seq_len, tag_classes]
        child_tag_scores = self.child_tag_feedforward(encoded_text)
        # [bsz, seq_len]
        child_scores = self.child_feedforward(encoded_text).squeeze(-1)

        # todo support cases that span_idx and span_tag are None
        # [bsz]
        batch_range_vector = get_range_vector(batch_size,
                                              get_device_of(encoded_text))
        # [bsz]
        gold_positions = span_idx[:, 0]

        # compute parent arc loss
        minus_inf = -1e8
        mrc_mask = torch.logical_and(mrc_mask, word_mask)
        parent_scores = parent_scores + (~mrc_mask).float() * minus_inf
        child_scores = child_scores + (~mrc_mask).float() * minus_inf

        # [bsz, seq_len]
        parent_logits = F.log_softmax(parent_scores, dim=-1)
        parent_arc_nll = -parent_logits[batch_range_vector,
                                        gold_positions].mean()

        # compute parent tag loss
        parent_tag_nll = F.cross_entropy(
            parent_tag_scores[batch_range_vector, gold_positions], span_tag)

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)
        child_probs = torch.sigmoid(child_scores)  # F.sigmoid is deprecated in favor of torch.sigmoid
        child_tag_probs = F.softmax(child_tag_scores, dim=-1)

        child_arc_loss = F.binary_cross_entropy_with_logits(child_scores,
                                                            child_arcs.float(),
                                                            reduction="none")
        child_arc_loss = (child_arc_loss *
                          mrc_mask.float()).sum() / mrc_mask.float().sum()
        child_tag_loss = F.cross_entropy(child_tag_scores.view(
            batch_size * seq_len, -1),
                                         child_tags.view(-1),
                                         reduction="none")
        child_tag_loss = (child_tag_loss * child_arcs.float().view(-1)
                          ).sum() / (child_arcs.float().sum() + 1e-8)

        return parent_probs, parent_tag_probs, child_probs, child_tag_probs, parent_arc_nll, parent_tag_nll, child_arc_loss, child_tag_loss

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """get word-level embedding"""
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return orig_embeddings
class SpanProposal(torch.nn.Module):
    """
    This model extracts the candidate start/end of the subtree span rooted at each token.

    Args:
        config: SpanProposal Config that defines model dim and structure
        bert_dir: pretrained bert directory

    """
    def __init__(self, config: SpanProposalConfig, bert_dir: str = ""):
        super().__init__()

        self.config = config

        num_pos_labels = len(config.pos_tags)
        hidden_size = config.additional_layer_dim if config.additional_layer > 0 else config.pos_dim + config.bert_config.hidden_size

        self.bert = AutoModel.from_pretrained(
            pretrained_model_name_or_path=bert_dir, config=config.bert_config)

        if config.pos_dim > 0:
            self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim)
            nn.init.xavier_uniform_(self.pos_embedding.weight)
            if (config.additional_layer
                    and config.additional_layer_type != "lstm"
                    and config.pos_dim + config.bert_config.hidden_size !=
                    hidden_size):
                self.fuse_layer = nn.Linear(
                    config.pos_dim + config.bert_config.hidden_size,
                    hidden_size)
                nn.init.xavier_uniform_(self.fuse_layer.weight)
                self.fuse_layer.bias.data.zero_()
            else:
                self.fuse_layer = None
        else:
            self.pos_embedding = None

        if config.additional_layer > 0:
            if config.additional_layer_type == "transformer":
                new_config = deepcopy(config.bert_config)
                new_config.hidden_size = hidden_size
                new_config.num_hidden_layers = config.additional_layer
                new_config.hidden_dropout_prob = new_config.attention_probs_dropout_prob = config.mrc_dropout
                # new_config.attention_probs_dropout_prob = config.biaf_dropout  # todo add to hparams and tune
                self.additional_encoder = BertEncoder(new_config)
                self.additional_encoder.apply(self._init_bert_weights)
            else:
                assert hidden_size % 2 == 0, "Bi-LSTM needs an even hidden_size"
                self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder(
                    input_size=config.pos_dim + config.bert_config.hidden_size,
                    hidden_size=hidden_size // 2,
                    num_layers=config.additional_layer,
                    recurrent_dropout_probability=config.mrc_dropout,
                    use_highway=True)

        else:
            self.additional_encoder = None

        self._dropout = InputVariationalDropout(config.mrc_dropout)

        self.subtree_start_feedforward = FeedForward(
            hidden_size, 1, config.arc_representation_dim,
            Activation.by_name("elu")())
        self.subtree_end_feedforward = deepcopy(self.subtree_start_feedforward)

        # todo: equivalent to self-attention?
        self.subtree_start_attention = BilinearMatrixAttention(
            config.arc_representation_dim,
            config.arc_representation_dim,
            use_input_biases=True)
        self.subtree_end_attention = deepcopy(self.subtree_start_attention)

        # init linear children
        for layer in self.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def _init_bert_weights(self, module):
        """ Initialize the weights. copy from transformers.BertPreTrainedModel"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.bert_config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @overrides
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        subtree_spans: torch.LongTensor = None,
    ):
        """  todo implement docstring
        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2]
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            subtree_spans: [batch_size, num_words, 2]
        Returns:
            span_start_logits: [batch_size, num_words, num_words]
            span_end_logits: [batch_size, num_words, num_words]
            span_loss: if subtree_spans is given.

        """
        # [bsz, seq_len, hidden]
        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, seq_len, encoding_dim = encoded_text.size()

        # [bsz, seq_len, dim]
        subtree_start_representation = self._dropout(
            self.subtree_start_feedforward(encoded_text))
        subtree_end_representation = self._dropout(
            self.subtree_end_feedforward(encoded_text))
        # [bsz, seq_len, seq_len]
        span_start_scores = self.subtree_start_attention(
            subtree_start_representation, subtree_start_representation)
        span_end_scores = self.subtree_end_attention(
            subtree_end_representation, subtree_end_representation)
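        # span_start_scores[b, i, j] / span_end_scores[b, i, j]: score that token j is the start /
        # end of the subtree span rooted at token i.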

        # the start of each word's subtree span must be at or before the word itself
        start_mask = word_mask.unsqueeze(-1) & (
            ~torch.triu(span_start_scores.bool(), 1))
        # the end of each word's subtree span must be at or after the word itself
        end_mask = word_mask.unsqueeze(-1) & torch.triu(span_end_scores.bool())

        minus_inf = -1e8
        span_start_scores = span_start_scores + (
            ~start_mask).float() * minus_inf
        span_end_scores = span_end_scores + (~end_mask).float() * minus_inf

        output = (F.log_softmax(span_start_scores,
                                dim=-1), F.log_softmax(span_end_scores,
                                                       dim=-1))
        if subtree_spans is not None:

            start_loss = F.cross_entropy(
                span_start_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 0].view(-1))
            end_loss = F.cross_entropy(
                span_end_scores.view(batch_size * seq_len, -1),
                subtree_spans[:, :, 1].view(-1))
            span_loss = start_loss + end_loss
            output = output + (span_loss, )

        return output

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """
        get word-level embedding
        """
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return orig_embeddings
Example #5
class BiaffineDependencyParser(nn.Module):
    """
    This dependency parser follows the model of
    [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)]
    (https://arxiv.org/abs/1611.01734) .
    But we use BERT for embeddings.
    """
    def __init__(self, bert_dir, config):
        super().__init__()

        self.config = config

        num_dep_labels = len(config.dep_tags)
        num_pos_labels = len(config.pos_tags)
        hidden_size = config.additional_layer_dim

        if config.pos_dim > 0:
            self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim)
            nn.init.xavier_uniform_(self.pos_embedding.weight)
            if config.additional_layer_type != "lstm" and config.pos_dim + config.hidden_size != hidden_size:
                self.fuse_layer = nn.Linear(
                    config.pos_dim + config.hidden_size, hidden_size)
                nn.init.xavier_uniform_(self.fuse_layer.weight)
                self.fuse_layer.bias.data.zero_()
            else:
                self.fuse_layer = None
        else:
            self.pos_embedding = None

        if isinstance(config, BertDependencyConfig):
            self.bert = BertModel.from_pretrained(bert_dir, config=self.config)
        elif isinstance(config, RobertaDependencyConfig):
            self.bert = RobertaModel.from_pretrained(bert_dir,
                                                     config=self.config)

        if config.additional_layer > 0:
            if config.additional_layer_type == "transformer":
                new_config = deepcopy(config)
                new_config.hidden_size = hidden_size
                new_config.num_hidden_layers = config.additional_layer
                new_config.hidden_dropout_prob = config.biaf_dropout
                new_config.attention_probs_dropout_prob = config.biaf_dropout  # todo add to hparams and tune
                self.additional_encoder = BertEncoder(new_config)
                self.additional_encoder.apply(self._init_bert_weights)

            else:
                assert hidden_size % 2 == 0, "Bi-LSTM needs an even hidden_size"
                self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder(
                    input_size=config.pos_dim + config.hidden_size,
                    hidden_size=hidden_size // 2,
                    num_layers=config.additional_layer,
                    recurrent_dropout_probability=config.biaf_dropout,
                    use_highway=True)

        else:
            self.additional_encoder = None

        self.head_arc_feedforward = FeedForward(hidden_size, 1,
                                                config.arc_representation_dim,
                                                Activation.by_name("elu")())
        self.child_arc_feedforward = deepcopy(self.head_arc_feedforward)
        self.arc_attention = BilinearMatrixAttention(
            config.arc_representation_dim,
            config.arc_representation_dim,
            use_input_biases=True)

        self.head_tag_feedforward = FeedForward(hidden_size, 1,
                                                config.tag_representation_dim,
                                                Activation.by_name("elu")())
        self.child_tag_feedforward = deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = nn.modules.Bilinear(config.tag_representation_dim,
                                                config.tag_representation_dim,
                                                num_dep_labels)
        nn.init.xavier_uniform_(self.tag_bilinear.weight)
        self.tag_bilinear.bias.data.zero_()

        self._dropout = InputVariationalDropout(config.biaf_dropout)
        self._input_dropout = nn.Dropout(config.biaf_dropout)

        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, hidden_size]))
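        # learned embedding for the artificial ROOT position prepended to every sentence in forward()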

    def _init_bert_weights(self, module):
        """ Initialize the weights. copy from transformers.BertPreTrainedModel"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @overrides
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        dep_idxs: torch.LongTensor,
        dep_tags: torch.LongTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
    ):

        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, _, encoding_dim = encoded_text.size()
        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask],
                              1)
        dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1)
        dep_tags = torch.cat([dep_tags.new_zeros(batch_size, 1), dep_tags], 1)
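        # position 0 is now the sentinel ROOT; its dummy head/tag entries are stripped from the
        # loss later in _construct_loss (arc_loss[:, 1:]).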

        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = ~word_mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)
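        # padded rows and columns of the arc score matrix are pushed towards -1e8 so they receive
        # ~zero probability mass after the (log-)softmax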

        if self.training:
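            # a greedy per-word argmax over heads is sufficient (and fast) during training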
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)
        else:
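            # at evaluation time, MST decoding is used so the predicted heads form a valid tree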
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)

        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=dep_idxs,
            head_tags=dep_tags,
            mask=word_mask,
        )

        return predicted_heads, predicted_head_tags, arc_nll, tag_nll

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """get word-level embedding"""
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return orig_embeddings

    def _construct_loss(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        head_indices: torch.Tensor,
        head_tags: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the arc and tag loss for a sequence given gold head indices and tags.

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The indices of the heads for every word.
        head_tags : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length).
            The dependency labels of the heads for every word.
        mask : `torch.BoolTensor`, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        # Returns

        arc_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc loss.
        tag_nll : `torch.Tensor`, required.
            The negative log likelihood from the arc tag loss.
        """
        batch_size, sequence_length, _ = attended_arcs.size()
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(attended_arcs)).unsqueeze(1)
        # shape (batch_size, sequence_length, sequence_length)
        normalised_arc_logits = (masked_log_softmax(attended_arcs, mask) *
                                 mask.unsqueeze(2) * mask.unsqueeze(1))

        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation,
                                              head_indices)
        normalised_head_tag_logits = masked_log_softmax(
            head_tag_logits, mask.unsqueeze(-1)) * mask.unsqueeze(-1)
        # index matrix with shape (batch_size, sequence_length)
        timestep_index = get_range_vector(sequence_length,
                                          get_device_of(attended_arcs))
        child_index = (timestep_index.view(1, sequence_length).expand(
            batch_size, sequence_length).long())
        # shape (batch_size, sequence_length)
        arc_loss = normalised_arc_logits[range_vector, child_index,
                                         head_indices]
        tag_loss = normalised_head_tag_logits[range_vector, child_index,
                                              head_tags]
        # We don't care about predictions for the symbolic ROOT token's head,
        # so we remove it from the loss.
        arc_loss = arc_loss[:, 1:]
        tag_loss = tag_loss[:, 1:]

        # The number of valid positions is equal to the number of unmasked elements minus
        # 1 per sequence in the batch, to account for the symbolic ROOT token.
        valid_positions = mask.sum() - batch_size

        arc_nll = -arc_loss.sum() / valid_positions.float()
        tag_nll = -tag_loss.sum() / valid_positions.float()
        return arc_nll, tag_nll
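
    # Illustrative note (not part of the original code): the advanced indexing above
    # picks out the log-probability of each gold arc. For one sentence with
    # head_indices = [0, 2, 0] (position 0 is the ROOT placeholder, word 1 attaches
    # to word 2, word 2 attaches to ROOT),
    #   arc_loss[b, i] = normalised_arc_logits[b, i, head_indices[b, i]]
    # Position 0 is then sliced off, and the summed losses are divided by
    # (mask.sum() - batch_size), i.e. averaged over the real tokens only.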

    def _greedy_decode(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions by decoding the unlabeled arcs
        independently for each word and then again, predicting the head tags of
        these greedily chosen arcs independently. Note that this method of decoding
        is not guaranteed to produce trees (i.e. there may be multiple roots,
        or cycles when children are attached to their parents).

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        mask : `torch.BoolTensor`, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        # Returns

        heads : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            greedily decoded heads of each word.
        head_tags : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the greedily decoded heads of each word.
        """
        # Mask the diagonal, because the head of a word can't be itself.
        attended_arcs = attended_arcs + torch.diag(
            attended_arcs.new(mask.size(1)).fill_(-np.inf))
        # Mask padded tokens, because we only want to consider actual words as heads.
        if mask is not None:
            minus_mask = ~mask.unsqueeze(2)
            attended_arcs.masked_fill_(minus_mask, -np.inf)

        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        _, heads = attended_arcs.max(dim=2)

        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag_representation,
                                              child_tag_representation, heads)
        _, head_tags = head_tag_logits.max(dim=2)
        return heads, head_tags
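
    # Illustrative sketch (not part of the original code) of greedy arc decoding on a
    # tiny 3-token score matrix (rows = children, columns = candidate heads):
    #   scores = torch.tensor([[[0.1, 0.9, 0.2],
    #                           [0.8, 0.3, 0.1],
    #                           [0.2, 0.7, 0.4]]])
    #   scores = scores + torch.diag(scores.new(3).fill_(-float("inf")))  # no self-heads
    #   heads = scores.max(dim=2).indices   # tensor([[1, 0, 1]])
    # Each word independently picks its highest-scoring head, which is why the
    # result is not guaranteed to form a tree.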

    def _mst_decode(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes the head and head tag predictions using Edmonds' algorithm
        for finding maximum spanning trees on directed graphs. Nodes in the
        graph are the words in the sentence, and between each pair of nodes,
        there is an edge in each direction, where the weight of the edge corresponds
        to the most likely dependency label probability for that arc. The MST is
        then generated from this directed graph.

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
            a distribution over attachments of a given word to all other words.
        mask : `torch.BoolTensor`, required.
            A mask of shape (batch_size, sequence_length), denoting unpadded
            elements in the sequence.

        # Returns

        heads : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            optimally decoded heads of each word.
        head_tags : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length) representing the
            dependency tags of the optimally decoded heads of each word.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size(
        )

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        expanded_shape = [
            batch_size, sequence_length, sequence_length,
            tag_representation_dim
        ]
        head_tag_representation = head_tag_representation.unsqueeze(2)
        head_tag_representation = head_tag_representation.expand(
            *expanded_shape).contiguous()
        child_tag_representation = child_tag_representation.unsqueeze(1)
        child_tag_representation = child_tag_representation.expand(
            *expanded_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(head_tag_representation,
                                                 child_tag_representation)

        # Note that this log_softmax is over the tag dimension, and we don't consider pairs
        # of tags which are invalid (e.g. a pair which includes a padded element) anyway below.
        # Shape (batch_size, num_labels, sequence_length, sequence_length)
        normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits,
                                                        dim=3).permute(
                                                            0, 3, 1, 2)

        # Mask padded tokens, because we only want to consider actual words as heads.
        minus_inf = -1e8
        minus_mask = ~mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        normalized_arc_logits = F.log_softmax(attended_arcs,
                                              dim=2).transpose(1, 2)

        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        # This energy tensor expresses the following relation:
        # energy[i,j] = "Score that i is the head of j". In this
        # case, we have heads pointing to their children.
        batch_energy = torch.exp(
            normalized_arc_logits.unsqueeze(1) +
            normalized_pairwise_head_logits)
        return self._run_mst_decoding(batch_energy, lengths)
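
    # Illustrative note (not part of the original code): in probability space,
    #   batch_energy[b, tag, head, child] = P(head | child) * P(tag | head, child)
    # so taking the max over the tag axis (as _run_mst_decoding does) scores each
    # directed head -> child edge by its single best label, and decode_mst then
    # searches for the highest-scoring spanning tree over those edge scores.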

    @staticmethod
    def _run_mst_decoding(
            batch_energy: torch.Tensor,
            lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all word -> ROOT
            # edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(),
                                           length,
                                           has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = []
            for child, parent in enumerate(instance_heads):
                instance_head_tags.append(tag_ids[parent, child].item())
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return (
            torch.from_numpy(np.stack(heads)).to(batch_energy.device),
            torch.from_numpy(np.stack(head_tags)).to(batch_energy.device),
        )

    def _get_head_tags(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        head_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Decodes the head tags given the head and child tag representations
        and a tensor of head indices to compute tags for. Note that these are
        either gold or predicted heads, depending on whether this function is
        being called to compute the loss, or if it's being called during inference.

        # Parameters

        head_tag_representation : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : `torch.Tensor`, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        head_indices : `torch.Tensor`, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        # Returns

        head_tag_logits : `torch.Tensor`
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        batch_size = head_tag_representation.size(0)
        # shape (batch_size, 1)
        range_vector = get_range_vector(
            batch_size, get_device_of(head_tag_representation)).unsqueeze(1)

        # This next statement is quite a complex piece of indexing, which you really
        # need to read the docs to understand. See here:
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
        # In effect, we are selecting the indices corresponding to the heads of each word from the
        # sequence length dimension for each element in the batch.

        # shape (batch_size, sequence_length, tag_representation_dim)
        selected_head_tag_representations = head_tag_representation[
            range_vector, head_indices]
        selected_head_tag_representations = selected_head_tag_representations.contiguous(
        )
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
                                            child_tag_representation)
        return head_tag_logits
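
    # Illustrative sketch (not part of the original code) of the advanced indexing
    # used above, on a toy batch of 2 sentences with 3 tokens each:
    #   rep = torch.arange(2 * 3 * 4, dtype=torch.float).view(2, 3, 4)
    #   head_indices = torch.tensor([[0, 2, 0],
    #                                [0, 0, 1]])
    #   range_vector = torch.arange(2).unsqueeze(1)   # shape (2, 1)
    #   selected = rep[range_vector, head_indices]    # shape (2, 3, 4)
    # selected[b, i] equals rep[b, head_indices[b, i]], i.e. the tag representation
    # of word i's head, which is then scored against word i's own child
    # representation by tag_bilinear.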