Example #1
class BertFine(BertPreTrainedModel):
    def __init__(self, bertConfig, num_classes):
        super(BertFine, self).__init__(bertConfig)
        self.bert = BertModel(bertConfig)  # BERT model
        self.dropout = nn.Dropout(bertConfig.hidden_dropout_prob)
        self.classifier = nn.Linear(in_features=bertConfig.hidden_size,
                                    out_features=num_classes)
        self.apply(self.init_bert_weights)
        # By default, every parameter of the BERT encoder is trained; with a batch size of 32 this takes roughly 8.7 GB of GPU memory.
        # The encoder can instead be frozen so that only the classifier layer is backpropagated; with a batch size of 32 this drops to roughly 1.1 GB.
        self.unfreeze_bert_encoder()

    def freeze_bert_encoder(self):
        for p in self.bert.parameters():
            p.requires_grad = False

    def unfreeze_bert_encoder(self):
        for p in self.bert.parameters():
            p.requires_grad = True

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                label_ids=None,
                output_all_encoded_layers=False):
        _, pooled_output = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
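# A minimal sketch of the freeze/unfreeze trade-off described in the comments above:
# train only the classifier head on top of a frozen encoder. It assumes the
# pytorch-pretrained-bert style API these snippets target; the config values and the
# dummy batch below are placeholders rather than part of the original example.
import torch
from pytorch_pretrained_bert import BertConfig

bert_config = BertConfig(vocab_size_or_config_json_file=30522)  # BERT-base defaults
model = BertFine(bert_config, num_classes=2)
model.freeze_bert_encoder()  # sets requires_grad = False on every encoder parameter

# Only the classifier parameters remain trainable, so only they go to the optimizer.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=2e-5)

input_ids = torch.randint(0, 30522, (32, 128))   # dummy batch of token ids
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)
label_ids = torch.randint(0, 2, (32,))

logits = model(input_ids, token_type_ids, attention_mask)
loss = torch.nn.functional.cross_entropy(logits, label_ids)
loss.backward()   # gradients reach only the classifier head
optimizer.step()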
class BertForSiameseClassification(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForSiameseClassification, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 2)
        self.apply(self.init_bert_weights)
        self.avg_vec = AvgVec()

    def forward(self, input_ids_1, input_mask_1, input_ids_2, input_mask_2):
        self.bert.eval()  # keep the shared encoder in eval mode so dropout is disabled for both towers
        encoder_layer_1, pooled_output_1 = self.bert(
            input_ids_1, token_type_ids=None, attention_mask=input_mask_1)
        encoder_layer_2, pooled_output_2 = self.bert(
            input_ids_2, token_type_ids=None, attention_mask=input_mask_2)
        out1 = self.avg_vec(encoder_layer_1, input_mask_1)
        out2 = self.avg_vec(encoder_layer_2, input_mask_2)

        out_norm = torch.abs(out1 - out2)  # element-wise |u - v| feature for the classifier
        logit = self.classifier(out_norm)
        softmax = F.softmax(logit, dim=1)
        return logit, softmax

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    def __init__(self, config, num_labels):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            # loss_fct = BCEWithLogitsLoss()
            # loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            # Note: the active loss treats `labels` as single-label class indices rather than multi-hot vectors.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
        
    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False
    
    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
Example #4
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    """
    def __init__(self, config, num_labels=2):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.hidden_size = config.hidden_size
        self.mem_size = 512
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.att = DocAttNet(sent_hidden_size=config.hidden_size,
                             doc_hidden_size=self.mem_size,
                             num_classes=num_labels)

        self.classifier = torch.nn.Linear(self.mem_size * 2, num_labels)
        self.classifier2 = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward2(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier2(pooled_output)
        return logits


    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, long_doc=True):

        #import pdb; pdb.set_trace()
        if long_doc:
            #self.freeze_bert_encoder()
            zs = []
            for i in range(input_ids.shape[1]):
                _, pooled_output = self.bert(input_ids[:,i], token_type_ids[:,i], attention_mask[:,i], output_all_encoded_layers=False)
                #pooled_output = self.dropout(pooled_output)
                zs.append(pooled_output.detach())

            mem = torch.zeros(2, input_ids.shape[0], self.mem_size).cuda()

            attention_output, word_attn_norm = self.att( torch.stack(zs, 0), mem)
            attention_output = self.dropout(attention_output)
            logits = self.classifier(attention_output)
            return logits, word_attn_norm
        else:
            #self.unfreeze_bert_encoder()
            _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier2(pooled_output)
            return logits

        
    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False
    
    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    def __init__(self, config, num_labels):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    @property
    def device(self) -> torch.device:
        return self.classifier.weight.device

    def forward(self,
                input_ids: torch.Tensor,
                token_type_ids=None,
                attention_mask: Optional[torch.Tensor] = None,
                labels=None):
        _, pooled_output = self.bert(input_ids,
                                     token_type_ids,
                                     attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1, self.num_labels))
            return loss
        else:
            return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

    @classmethod
    def load(cls, modelpath, config, num_labels):
        print('loading model from [%s]' % modelpath, file=sys.stderr)
        model = cls(config, num_labels)
        model.load_state_dict(torch.load(modelpath))
        return model
Example #6
class BertFine(BertPreTrainedModel):
    def __init__(self, bertConfig, num_classes):
        super(BertFine, self).__init__(bertConfig)
        self.bert = BertModel(bertConfig)
        self.dropout = nn.Dropout(bertConfig.hidden_dropout_prob)
        n = 1
        if config['feature-based'] == 'Finetune_All': n = bertConfig.num_hidden_layers
        elif config['feature-based'] == 'Second_to_Last': n = bertConfig.num_hidden_layers-1
        elif config['feature-based'] == 'Concat_Last_Four': n = 4
        self.pooler = BertFinalPooler(bertConfig.hidden_size, n)
        self.classifier = nn.Linear(in_features=bertConfig.hidden_size*n, out_features=num_classes)
        self.apply(self.init_bert_weights)
        self.unfreeze_bert_encoder() 

    def freeze_bert_encoder(self):
        for p in self.bert.parameters():
            p.requires_grad = False

    def unfreeze_bert_encoder(self):
        for p in self.bert.parameters():
            p.requires_grad = True

    def forward(self, input_ids, token_type_ids, attention_mask, label_ids=None, output_all_encoded_layers=True):
        encoded_layers, pooled_output = self.bert(input_ids,
                                        token_type_ids,
                                        attention_mask,
                                        output_all_encoded_layers=output_all_encoded_layers)
        
        if config['feature-based'] != 'Last':
            if config['feature-based'] == 'Finetune_All':
                sequence_output = torch.cat(encoded_layers,2)
            elif config['feature-based'] == 'First':
                sequence_output = encoded_layers[0]
            elif config['feature-based'] == 'Second_to_Last':
                sequence_output = torch.cat(encoded_layers[1:],1)
            elif config['feature-based'] == 'Sum_Last_Four':
                sequence_output = sum(encoded_layers[-4:])
            elif config['feature-based'] == 'Concat_Last_Four':
                sequence_output = torch.cat(encoded_layers[-4:],2)
            elif config['feature-based'] == 'Sum_All':
                sequence_output = sum(encoded_layers)
            pooled_output = self.pooler(sequence_output)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits
Example #7
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):  # type: ignore
    """BERT model with a multi-label classification head on top of the pooled output."""
    def __init__(self, config: BertConfig, num_labels: int = 2):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Tensor = None,
        attention_mask: Tensor = None,
        labels: Tensor = None,
        pos_weight: Tensor = None,
    ) -> Tensor:
        _, pooled_output = self.bert(input_ids,
                                     token_type_ids,
                                     attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            if pos_weight is None:
                loss_fct = BCEWithLogitsLoss()
            else:
                loss_fct = BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1, self.num_labels))
            return loss
        else:
            return logits

    def freeze_bert_encoder(self) -> None:
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self) -> None:
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    """
    def __init__(self, config, num_labels=17, mobilebert=True):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.mobilebert = mobilebert
        self.num_labels = num_labels
        self.bert = BertModel(config) if not mobilebert else MobileBertModel.from_pretrained(
            'google/mobilebert-uncased',
            num_labels=num_labels,)
        
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        if not mobilebert:
            self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        labels = labels.to(torch.float)
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(logits, labels)
        return loss, logits

        
    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False
    
    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    """
    def __init__(self, config, num_labels=2):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.weight = torch.tensor([0.1, 0.1, 0.2, 0.6])
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        _, pooled_output = self.bert(input_ids,
                                     token_type_ids,
                                     attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1, self.num_labels))
            return loss, logits
        else:
            return logits

    def freeze(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForSiameseClassification(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForSiameseClassification, self).__init__(config)
        self.bert = BertModel(config)
        self.apply(self.init_bert_weights)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.cosVec = cosVec()

    def forward(self, input_ids_1, input_mask_1, input_ids_2, input_mask_2):
        encoder_layer_1, pooled_output_1 = self.bert(
            input_ids_1, token_type_ids=None, attention_mask=input_mask_1)
        encoder_layer_2, pooled_output_2 = self.bert(
            input_ids_2, token_type_ids=None, attention_mask=input_mask_2)
        sim = self.cosVec(pooled_output_1, pooled_output_2)
        return pooled_output_1, pooled_output_2, sim

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
Example #11
def add_adapters(bert_model: BertModel, config: AdapterConfig) -> BertModel:
    bert_encoder = bert_model.encoder
    for i in range(len(bert_model.encoder.layer)):
        bert_encoder.layer[i].attention.output = adapt_bert_self_output(
            config)(bert_encoder.layer[i].attention.output)

    # Freeze all parameters
    for param in bert_model.parameters():
        param.requires_grad = False
    # Unfreeze trainable parts — layer norms and adapters
    for name, sub_module in bert_model.named_modules():
        if isinstance(sub_module, (Adapter, BertLayerNorm)):
            for param_name, param in sub_module.named_parameters():
                param.requires_grad = True
    return bert_model
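# A quick sanity check that add_adapters froze what it was meant to: count trainable
# parameters before and after. This is a sketch under assumptions -- the AdapterConfig
# fields shown are illustrative guesses, and BertModel.from_pretrained refers to the
# pytorch-pretrained-bert weights name 'bert-base-uncased'.
from pytorch_pretrained_bert import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')
total = sum(p.numel() for p in bert.parameters())

bert = add_adapters(bert, AdapterConfig(hidden_size=768, adapter_size=64))
trainable = sum(p.numel() for p in bert.parameters() if p.requires_grad)

# Only the adapter bottlenecks and the LayerNorms should remain trainable,
# a small fraction of the full encoder.
print('trainable / total parameters: %d / %d' % (trainable, total))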
Example #12
    def __init__(self,
                 bert_config: str,
                 requires_grad: bool = False,
                 dropout: float = 0.1,
                 layer_dropout: float = 0.1,
                 combine_layers: str = "mix") -> None:
        model = BertModel(BertConfig.from_json_file(bert_config))

        for param in model.parameters():
            param.requires_grad = requires_grad

        super().__init__(bert_model=model,
                         layer_dropout=layer_dropout,
                         combine_layers=combine_layers)

        self.model = model
        self.dropout = dropout
        self.set_dropout(dropout)
class BertForLabelEncoding(PreTrainedBertModel):
    def __init__(self, config, trainable=False):
        super(BertForLabelEncoding, self).__init__(config)

        self.config = config
        self.bert = BertModel(config)
        #self.apply(self.init_bert_weights)     # don't need to perform due to pre-trained params loading

        if not trainable:
            for p in self.bert.parameters():
                p.requires_grad = False

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                output_all_encoded_layers=False):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                     output_all_encoded_layers)
        return pooled_output
Example #14
class BertLSTMForClassification(BertPreTrainedModel):

    def __init__(self, config, encoder, attention, hidden_dim, num_labels):
        super(BertLSTMForClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)

        self.encoder = encoder
        self.attention = attention
        self.decoder = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        outputs, hidden = self.encoder(encoded_layers)
        if isinstance(hidden, tuple): # LSTM
            hidden = hidden[1] # take the cell state

        if self.encoder.bidirectional: # need to concat the last 2 hidden layers
            hidden = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden = hidden[-1]

        # max across T?
        # Other options (work worse on a few tests):
        # linear_combination, _ = torch.max(outputs, 0)
        # linear_combination = torch.mean(outputs, 0)

        energy, linear_combination = self.attention(hidden, outputs, outputs)
        logits = self.decoder(linear_combination)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits

    def freeze_bert(self):
        for param in self.bert.parameters():
            param.requires_grad = False
Example #15
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, num_labels]
            with indices selected in [0, ..., num_labels].
    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_labels].
    """
    def __init__(self, config, num_labels=2, loss_fct="bbce"):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.loss_fct = loss_fct
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        _, pooled_output = self.bert(input_ids,
                                     token_type_ids,
                                     attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            if self.loss_fct == "bbce":
                loss_fct = BalancedBCEWithLogitsLoss()
            else:
                loss_fct = torch.nn.MultiLabelSoftMarginLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1, self.num_labels))
            return loss
        else:
            return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
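# A short usage sketch for the multi-label head above: `labels` is a multi-hot
# [batch_size, num_labels] tensor, and forward() returns a loss when labels are given
# and logits otherwise. The config construction is a placeholder, and loss_fct is set
# to a non-"bbce" value so the sketch only needs torch's MultiLabelSoftMarginLoss
# rather than the project's BalancedBCEWithLogitsLoss.
import torch
from pytorch_pretrained_bert import BertConfig

config = BertConfig(vocab_size_or_config_json_file=30522)
model = BertForMultiLabelSequenceClassification(config, num_labels=6,
                                                loss_fct='multilabel_soft_margin')

input_ids = torch.randint(0, 30522, (4, 64))
labels = torch.randint(0, 2, (4, 6)).float()   # multi-hot targets, one column per label

loss = model(input_ids, labels=labels)         # training path: scalar loss
logits = model(input_ids)                      # inference path: [4, 6] logits
probs = torch.sigmoid(logits)                  # independent per-label probabilities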
class BertQAYesnoHierarchicalReinforceRACE(BertPreTrainedModel):
    """
    Hard attention using reinforce learning
    """
    def __init__(self,
                 config,
                 evidence_lambda=0.8,
                 num_choices=4,
                 sample_steps: int = 5,
                 reward_func: int = 0,
                 freeze_bert=False):
        super(BertQAYesnoHierarchicalReinforceRACE, self).__init__(config)
        logger.info(f'The model {self.__class__.__name__} is loading...')
        logger.info(f'The coefficient of evidence loss is {evidence_lambda}')
        logger.info(f'Currently the number of choices is {num_choices}')
        logger.info(f'Sample steps: {sample_steps}')
        logger.info(f'Reward function: {reward_func}')
        logger.info(f'Freeze BERT parameters: {freeze_bert}')

        layers.set_seq_dropout(True)
        layers.set_my_dropout_prob(config.hidden_dropout_prob)
        rep_layers.set_seq_dropout(True)
        rep_layers.set_my_dropout_prob(config.hidden_dropout_prob)

        self.bert = BertModel(config)

        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        self.doc_sen_self_attn = rep_layers.LinearSelfAttention(
            config.hidden_size)
        self.que_self_attn = rep_layers.LinearSelfAttention(config.hidden_size)

        self.word_similarity = layers.AttentionScore(config.hidden_size,
                                                     250,
                                                     do_similarity=False)
        self.vector_similarity = layers.AttentionScore(config.hidden_size,
                                                       250,
                                                       do_similarity=False)

        # self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3)
        self.classifier = nn.Linear(config.hidden_size * 2, 1)
        self.evidence_lam = evidence_lambda
        self.sample_steps = sample_steps
        self.reward_func = [self.reinforce_step,
                            self.reinforce_step_1][reward_func]
        self.num_choices = num_choices

        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                sentence_span_list=None,
                sentence_ids=None,
                max_sentences: int = 0):
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(
            -1,
            token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(
            -1,
            attention_mask.size(-1)) if attention_mask is not None else None
        sequence_output, _ = self.bert(flat_input_ids,
                                       flat_token_type_ids,
                                       flat_attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, flat_token_type_ids, flat_attention_mask, sentence_span_list,
                                         max_sentences=max_sentences)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)
        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask,
                                                dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc,
                                          doc_mask).view(batch, max_sen, -1)

        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        if self.training:
            _sample_prob, _sample_log_prob = self.sample_one_hot(
                sentence_sim, sentence_mask)
            loss_and_reward, _ = self.reward_func(word_hidden, que_vec, labels,
                                                  _sample_prob,
                                                  _sample_log_prob)
            output_dict = {'loss': loss_and_reward}
        else:
            _prob, _ = self.sample_one_hot(sentence_sim, sentence_mask)
            loss, _choice_logits = self.simple_step(word_hidden, que_vec,
                                                    labels, _prob)
            sentence_scores = rep_layers.masked_softmax(sentence_sim,
                                                        sentence_mask,
                                                        dim=-1).squeeze_(1)
            output_dict = {
                'sentence_logits': sentence_scores.float(),
                'loss': loss,
                'choice_logits': _choice_logits.float()
            }

        return output_dict

    def sample_one_hot(self, _similarity, _mask):
        _probability = rep_layers.masked_softmax(_similarity, _mask)
        dtype = _probability.dtype
        _probability = _probability.float()
        # _log_probability = masked_log_softmax(_similarity, _mask)
        if self.training:
            _distribution = Categorical(_probability)
            _sample_index = _distribution.sample((self.sample_steps, ))
            logger.debug(str(_sample_index.size()))
            new_shape = (self.sample_steps, ) + _similarity.size()
            logger.debug(str(new_shape))
            _sample_one_hot = F.one_hot(_sample_index,
                                        num_classes=_similarity.size(-1))
            # _sample_one_hot = _similarity.new_zeros(new_shape).scatter(-1, _sample_index.unsqueeze(-1), 1.0)
            logger.debug(str(_sample_one_hot.size()))
            _log_prob = _distribution.log_prob(
                _sample_index)  # sample_steps, batch, 1
            assert _log_prob.size() == new_shape[:-1], (_log_prob.size(),
                                                        new_shape)
            _sample_one_hot = _sample_one_hot.transpose(
                0, 1)  # batch, sample_steps, 1, max_sen
            _log_prob = _log_prob.transpose(0, 1)  # batch, sample_steps, 1
            return _sample_one_hot.to(dtype=dtype), _log_prob.to(dtype=dtype)
        else:
            _max_index = _probability.float().max(dim=-1, keepdim=True)[1]
            _one_hot = torch.zeros_like(_similarity).scatter_(
                -1, _max_index, 1.0)
            # _log_prob = _log_probability.gather(-1, _max_index)
            return _one_hot, None

    def reinforce_step(self, hidden, q_vec, label, prob, log_prob):
        batch, max_sen, hidden_dim = hidden.size()
        assert q_vec.size() == (batch, 1, hidden_dim)
        assert prob.size() == (batch, self.sample_steps, 1, max_sen)
        assert log_prob.size() == (batch, self.sample_steps, 1)
        expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1,
                                                     -1)
        h = prob.matmul(expanded_hidden).squeeze(
            2)  # batch, sample_steps, hidden_dim
        q = q_vec.expand(-1, self.sample_steps, -1)
        # _logits = self.classifier(torch.cat([h, q], dim=2)).view(-1, self.num_choices)  # batch, sample_steps, 3
        # Note the rank of dimension here
        _logits = self.classifier(torch.cat([h, q], dim=2)).view(label.size(0), self.num_choices, self.sample_steps)\
            .transpose(1, 2).reshape(-1, self.num_choices)
        expanded_label = label.unsqueeze(1).expand(
            -1, self.sample_steps).reshape(-1)
        _loss = F.cross_entropy(_logits, expanded_label)
        corrects = (_logits.max(dim=-1)[1] == expanded_label).to(hidden.dtype)
        log_prob = log_prob.reshape(label.size(0), self.num_choices,
                                    self.sample_steps).transpose(
                                        1, 2).mean(dim=-1)
        reward1 = (log_prob.reshape(-1) *
                   corrects).sum() / (self.sample_steps * label.size(0))
        return _loss - reward1, _logits

    def reinforce_step_1(self, hidden, q_vec, label, prob, log_prob):
        batch, max_sen, hidden_dim = hidden.size()
        assert q_vec.size() == (batch, 1, hidden_dim)
        assert prob.size() == (batch, self.sample_steps, 1, max_sen)
        assert log_prob.size() == (batch, self.sample_steps, 1)
        expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1,
                                                     -1)
        h = prob.matmul(expanded_hidden).squeeze(
            2)  # batch, sample_steps, hidden_dim
        q = q_vec.expand(-1, self.sample_steps, -1)
        # _logits = self.classifier(torch.cat([h, q], dim=2)).view(-1, self.num_choices)  # batch * sample_steps, 3
        _logits = self.classifier(torch.cat([h, q], dim=2)).view(label.size(0), self.num_choices, self.sample_steps)\
            .transpose(1, 2).reshape(-1, self.num_choices)
        expanded_label = label.unsqueeze(1).expand(
            -1, self.sample_steps).reshape(-1)  # batch * sample_steps

        _loss = F.cross_entropy(_logits, expanded_label)

        _final_log_prob = F.log_softmax(_logits, dim=-1)
        # ignore_mask = (expanded_label == -1)
        # expanded_label = expanded_label.masked_fill(ignore_mask, 0)
        selected_log_prob = _final_log_prob.gather(
            1, expanded_label.unsqueeze(1)).squeeze(-1)  # batch * sample_steps
        assert selected_log_prob.size() == (
            label.size(0) * self.sample_steps, ), selected_log_prob.size()
        log_prob = log_prob.reshape(label.size(0), self.num_choices,
                                    self.sample_steps).transpose(
                                        1, 2).mean(dim=-1)
        # reward2 = - (log_prob.reshape(-1) * (selected_log_prob * (1 - ignore_mask).to(log_prob.dtype))).sum() / (
        #         self.sample_steps * batch)
        reward2 = -(log_prob.reshape(-1) * selected_log_prob).sum() / (
            self.sample_steps * label.size(0))

        return _loss - reward2, _logits

    def simple_step(self, hidden, q_vec, label, prob):
        batch, max_sen, hidden_dim = hidden.size()
        assert q_vec.size() == (batch, 1, hidden_dim)
        assert prob.size() == (batch, 1, max_sen)
        h = prob.bmm(hidden)
        _logits = self.classifier(torch.cat([h, q_vec],
                                            dim=2)).view(-1, self.num_choices)
        if label is not None:
            _loss = F.cross_entropy(_logits, label)
        else:
            _loss = _logits.new_zeros(1)
        return _loss, _logits
class BertQAYesnoHierarchicalHardRACE(BertPreTrainedModel):
    """
    Hard:
    Hard attention, using gumbel softmax of reinforcement learning.
    """
    def __init__(self,
                 config,
                 evidence_lambda=0.8,
                 num_choices=4,
                 use_gumbel=True,
                 freeze_bert=False):
        super(BertQAYesnoHierarchicalHardRACE, self).__init__(config)
        logger.info(f'The model {self.__class__.__name__} is loading...')
        logger.info(f'The coefficient of evidence loss is {evidence_lambda}')
        logger.info(f'Currently the number of choices is {num_choices}')
        logger.info(f'Use gumbel: {use_gumbel}')
        logger.info(f'Freeze BERT parameters: {freeze_bert}')

        layers.set_seq_dropout(True)
        layers.set_my_dropout_prob(config.hidden_dropout_prob)
        rep_layers.set_seq_dropout(True)
        rep_layers.set_my_dropout_prob(config.hidden_dropout_prob)

        self.bert = BertModel(config)

        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        # self.doc_sen_self_attn = layers.LinearSelfAttnAllennlp(config.hidden_size)
        # self.que_self_attn = layers.LinearSelfAttn(config.hidden_size)
        self.doc_sen_self_attn = rep_layers.LinearSelfAttention(
            config.hidden_size)
        self.que_self_attn = rep_layers.LinearSelfAttention(config.hidden_size)

        self.word_similarity = layers.AttentionScore(config.hidden_size,
                                                     250,
                                                     do_similarity=False)
        self.vector_similarity = layers.AttentionScore(config.hidden_size,
                                                       250,
                                                       do_similarity=False)

        self.classifier = nn.Linear(config.hidden_size * 2, 1)
        self.evidence_lam = evidence_lambda
        self.use_gumbel = use_gumbel
        self.num_choices = num_choices

        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                sentence_span_list=None,
                sentence_ids=None,
                max_sentences: int = 0):
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(
            -1,
            token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(
            -1,
            attention_mask.size(-1)) if attention_mask is not None else None
        sequence_output, _ = self.bert(flat_input_ids,
                                       flat_token_type_ids,
                                       flat_attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, flat_token_type_ids, flat_attention_mask, sentence_span_list,
                                         max_sentences=max_sentences)

        batch, max_sen, doc_len = doc_sen_mask.size()
        # que_len = que_mask.size(1)

        # que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view(batch, 1, -1)
        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)
        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask,
                                                dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc,
                                          doc_mask).view(batch, max_sen, -1)

        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_hidden = self.hard_sample(
            sentence_sim,
            use_gumbel=self.use_gumbel,
            dim=-1,
            hard=True,
            mask=sentence_mask).bmm(word_hidden).squeeze(1)

        choice_logits = self.classifier(
            torch.cat([sentence_hidden, que_vec.squeeze(1)],
                      dim=1)).reshape(-1, self.num_choices)

        sentence_scores = rep_layers.masked_softmax(sentence_sim,
                                                    sentence_mask,
                                                    dim=-1).squeeze_(1)
        output_dict = {
            'choice_logits':
            choice_logits.float(),
            'sentence_logits':
            sentence_scores.reshape(choice_logits.size(0), self.num_choices,
                                    max_sen).detach().cpu().float(),
        }
        loss = 0
        if labels is not None:
            choice_loss = F.cross_entropy(choice_logits, labels)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = rep_layers.masked_log_softmax(
                sentence_sim.squeeze(1), sentence_mask, dim=-1)
            sentence_loss = F.nll_loss(log_sentence_sim,
                                       sentence_ids.view(batch),
                                       reduction='sum',
                                       ignore_index=-1)
            loss += self.evidence_lam * sentence_loss / choice_logits.size(0)
        output_dict['loss'] = loss
        return output_dict

    def hard_sample(self, logits, use_gumbel, dim=-1, hard=True, mask=None):
        if use_gumbel:
            if self.training:
                probs = rep_layers.gumbel_softmax(logits,
                                                  mask=mask,
                                                  hard=hard,
                                                  dim=dim)
                return probs
            else:
                probs = rep_layers.masked_softmax(logits, mask, dim=dim)
                index = probs.max(dim, keepdim=True)[1]
                y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
                return y_hard
        else:
            # The non-Gumbel sampling branch is not implemented in this variant.
            pass
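# The hard_sample helper above follows the straight-through Gumbel-softmax pattern;
# this sketch shows the same train/eval split with torch's built-in F.gumbel_softmax,
# leaving out the masking that the project's rep_layers helpers add.
import torch
import torch.nn.functional as F

def hard_select(logits, training):
    if training:
        # Differentiable one-hot sample: hard forward pass, soft gradients.
        return F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
    # At eval time, pick the argmax deterministically.
    index = logits.softmax(dim=-1).max(dim=-1, keepdim=True)[1]
    return torch.zeros_like(logits).scatter_(-1, index, 1.0)

sentence_sim = torch.randn(2, 1, 5)                   # [batch, 1, max_sentences]
one_hot = hard_select(sentence_sim, training=True)    # one-hot over sentences, gradient-friendly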
Example #18
class SANBertNetwork(nn.Module):
    def __init__(self, opt, bert_config=None):
        super(SANBertNetwork, self).__init__()
        self.dropout_list = nn.ModuleList()
        self.bert_config = BertConfig.from_dict(opt)
        self.bert = BertModel(self.bert_config)
        if opt.get('dump_feature', False):
            self.opt = opt
            return
        if opt['update_bert_opt'] > 0:
            for p in self.bert.parameters():
                p.requires_grad = False
        mem_size = self.bert_config.hidden_size
        self.decoder_opt = opt['answer_opt']
        self.scoring_list = nn.ModuleList()
        labels = [int(ls) for ls in opt['label_size'].split(',')]
        task_dropout_p = opt['tasks_dropout_p']
        self.bert_pooler = None

        for task, lab in enumerate(labels):
            decoder_opt = self.decoder_opt[task]
            dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
            self.dropout_list.append(dropout)
            if decoder_opt == 1:
                out_proj = SANClassifier(mem_size, mem_size, lab, opt, prefix='answer', dropout=dropout)
                self.scoring_list.append(out_proj)
            else:
                out_proj = nn.Linear(self.bert_config.hidden_size, lab)
                self.scoring_list.append(out_proj)

        self.opt = opt
        self._my_init()
        self.set_embed(opt)

    def _my_init(self):
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range * self.opt['init_ratio'])
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT pytorch version, which should be a bug.
                # Note that it only affects on training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450)
                # support both old/latest version
                if 'beta' in dir(module) and 'gamma' in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

        self.apply(init_weights)

    def nbert_layer(self):
        return len(self.bert.encoder.layer)

    def freeze_layers(self, max_n):
        assert max_n < self.nbert_layer()
        for i in range(0, max_n):
            self.freeze_layer(i)

    def freeze_layer(self, n):
        assert n < self.nbert_layer()
        layer = self.bert.encoder.layer[n]
        for p in layer.parameters():
            p.requires_grad = False

    def set_embed(self, opt):
        bert_embeddings = self.bert.embeddings
        emb_opt = opt['embedding_opt']
        if emb_opt == 1:
            for p in bert_embeddings.word_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 2:
            for p in bert_embeddings.position_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 3:
            for p in bert_embeddings.token_type_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 4:
            for p in bert_embeddings.token_type_embeddings.parameters():
                p.requires_grad = False
            for p in bert_embeddings.position_embeddings.parameters():
                p.requires_grad = False

    def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0):
        all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        sequence_output = all_encoder_layers[-1]
        if self.bert_pooler is not None:
            pooled_output = self.bert_pooler(sequence_output)
        decoder_opt = self.decoder_opt[task_id]
        if decoder_opt == 1:
            max_query = hyp_mask.size(1)
            assert max_query > 0
            assert premise_mask is not None
            assert hyp_mask is not None
            hyp_mem = sequence_output[:, :max_query, :]
            logits = self.scoring_list[task_id](sequence_output, hyp_mem, premise_mask, hyp_mask)
        else:
            pooled_output = self.dropout_list[task_id](pooled_output)
            logits = self.scoring_list[task_id](pooled_output)
        return logits
class Bert_CRF(BertPreTrainedModel):
    def __init__(self, config, num_tag):
        super(Bert_CRF, self).__init__(config)
        self.bert = BertModel(config)
        if args.do_not_train_ernie:
            for p in self.bert.parameters():
                p.requires_grad = False
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_tag)
        self.apply(self.init_bert_weights)
        self.crf = CRF(num_tag)
        self.num_tag = num_tag

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                label_id=None,
                output_all_encoded_layers=False):
        bert_encode, _ = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        output = self.classifier(bert_encode)

        return output

    def loss_fn(self, bert_encode, output_mask, tags):
        if args.do_CRF:
            loss = self.crf.negative_log_loss(bert_encode, output_mask, tags)
        else:
            loss = torch.autograd.Variable(torch.tensor(0.),
                                           requires_grad=True)
            for ix, (features, tag) in enumerate(zip(bert_encode, tags)):
                num_valid = torch.sum(output_mask[ix].detach())
                features = features[output_mask[ix] == 1]
                tag = tag[:num_valid]
                loss_fct = nn.CrossEntropyLoss(ignore_index=0)
                loss = loss + loss_fct(
                    features.view(-1, self.num_tag).cpu(),
                    tag.view(-1).cpu())
        return loss

    def predict(self, bert_encode, output_mask):
        if args.do_CRF:
            predicts = self.crf.get_batch_best_path(bert_encode, output_mask)
            if not args.do_inference:
                predicts = predicts.view(1, -1).squeeze()
                predicts = predicts[predicts != -1]
            else:
                predicts_ = []
                for ix, features in enumerate(predicts):
                    #features = features[output_mask[ix] == 1]
                    predict = features[features != -1]
                    predicts_.append(predict)
                predicts = predicts_
        else:
            predicts_ = []
            for ix, features in enumerate(bert_encode):
                features = features[output_mask[ix] == 1]
                predict = F.softmax(features, dim=1)
                predict = torch.argmax(predict, dim=1)
                predicts_.append(predict)
            if not args.do_inference:
                predicts = torch.cat(predicts_, 0)
            else:
                predicts = predicts_
        return predicts

    def acc_f1(self, y_pred, y_true):
        try:
            y_pred = y_pred.numpy()
            y_true = y_true.numpy()
        except Exception:
            pass
        f1 = f1_score(y_true, y_pred, average="macro")
        correct = np.sum((y_true == y_pred).astype(int))
        acc = correct / y_pred.shape[0]
        return acc, f1

    def class_report(self, y_pred, y_true):
        y_true = y_true.numpy()
        y_pred = y_pred.numpy()
        classify_report = classification_report(y_true, y_pred)
        print('\n\nclassify_report:\n', classify_report)
class BertQAYesnoHierarchicalReinforce(BertPreTrainedModel):
    """
    Hard attention using reinforce learning
    """
    def __init__(self,
                 config,
                 evidence_lambda=0.8,
                 sample_steps: int = 5,
                 reward_func: int = 0,
                 freeze_bert=False):
        super(BertQAYesnoHierarchicalReinforce, self).__init__(config)
        logger.info(f'The model {self.__class__.__name__} is loading...')
        logger.info(f'The coefficient of evidence loss is {evidence_lambda}')
        logger.info(f'Sample steps: {sample_steps}')
        logger.info(f'Reward function: {reward_func}')
        logger.info(f'Freeze BERT parameters: {freeze_bert}')

        layers.set_seq_dropout(True)
        layers.set_my_dropout_prob(config.hidden_dropout_prob)

        self.bert = BertModel(config)

        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        self.doc_sen_self_attn = layers.LinearSelfAttnAllennlp(
            config.hidden_size)
        self.que_self_attn = layers.LinearSelfAttn(config.hidden_size)

        self.word_similarity = layers.AttentionScore(config.hidden_size,
                                                     250,
                                                     do_similarity=False)
        self.vector_similarity = layers.AttentionScore(config.hidden_size,
                                                       250,
                                                       do_similarity=False)

        self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3)
        self.evidence_lam = evidence_lambda
        self.sample_steps = sample_steps
        self.reward_func = [self.reinforce_step,
                            self.reinforce_step_1][reward_func]

        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        batch, max_sen, doc_len = doc_sen_mask.size()
        # que_len = que_mask.size(1)

        que_vec = layers.weighted_avg(que, self.que_self_attn(que,
                                                              que_mask)).view(
                                                                  batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, 1 - doc_mask,
                                     dim=1).unsqueeze(1).bmm(doc)
        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = layers.weighted_avg(doc,
                                       self.doc_sen_self_attn(doc,
                                                              doc_mask)).view(
                                                                  batch,
                                                                  max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        # sentence_hidden = self.hard_sample(sentence_sim, use_gumbel=self.use_gumbel, dim=-1,
        #                                    hard=True, mask=(1 - sentence_mask)).bmm(word_hidden).squeeze(1)
        if self.training:
            _sample_prob, _sample_log_prob = self.sample_one_hot(
                sentence_sim, 1 - sentence_mask)
            loss_and_reward, _ = self.reward_func(word_hidden, que_vec,
                                                  answer_choice, _sample_prob,
                                                  _sample_log_prob)
            output_dict = {'loss': loss_and_reward}
        else:
            _prob, _ = self.sample_one_hot(sentence_sim, 1 - sentence_mask)
            loss, _yesno_logits = self.simple_step(word_hidden, que_vec,
                                                   answer_choice, _prob)
            sentence_scores = masked_softmax(sentence_sim,
                                             1 - sentence_mask,
                                             dim=-1).squeeze_(1)
            output_dict = {
                'max_weight': sentence_scores.max(dim=1)[0],
                'max_weight_index': sentence_scores.max(dim=1)[1],
                'sentence_logits': sentence_scores,
                'loss': loss,
                'yesno_logits': _yesno_logits
            }

        return output_dict

        # yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))
        #
        # sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1).squeeze_(1)
        # output_dict = {'yesno_logits': yesno_logits,
        #                'sentence_logits': sentence_scores,
        #                'max_weight_index': sentence_scores.max(dim=1)[1],
        #                'max_weight': sentence_scores.max(dim=1)[0]}
        # loss = 0
        # if answer_choice is not None:
        #     choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
        #     loss += choice_loss
        # if sentence_ids is not None:
        #     log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1)
        #     sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
        #     loss += sentence_loss
        # output_dict['loss'] = loss
        # return output_dict

    def sample_one_hot(self, _similarity, _mask):
        _probability = masked_softmax(_similarity, _mask)
        # _log_probability = masked_log_softmax(_similarity, _mask)
        if self.training:
            _distribution = Categorical(_probability)
            _sample_index = _distribution.sample((self.sample_steps, ))
            new_shape = (self.sample_steps, ) + _similarity.size()
            _sample_one_hot = _similarity.new_zeros(new_shape).scatter(
                -1, _sample_index.unsqueeze(-1), 1.0)
            _log_prob = _distribution.log_prob(
                _sample_index)  # sample_steps, batch, 1
            assert _log_prob.size() == new_shape[:-1], (_log_prob.size(),
                                                        new_shape)
            _sample_one_hot = _sample_one_hot.transpose(
                0, 1)  # batch, sample_steps, 1, max_sen
            _log_prob = _log_prob.transpose(0, 1)  # batch, sample_steps, 1
            return _sample_one_hot, _log_prob
        else:
            _max_index = _probability.max(dim=-1, keepdim=True)[1]
            _one_hot = torch.zeros_like(_similarity).scatter_(
                -1, _max_index, 1.0)
            # _log_prob = _log_probability.gather(-1, _max_index)
            return _one_hot, None

    def reinforce_step(self, hidden, q_vec, label, prob, log_prob):
        batch, max_sen, hidden_dim = hidden.size()
        assert q_vec.size() == (batch, 1, hidden_dim)
        assert prob.size() == (batch, self.sample_steps, 1, max_sen)
        assert log_prob.size() == (batch, self.sample_steps, 1)
        expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1,
                                                     -1)
        h = prob.matmul(expanded_hidden).squeeze(
            2)  # batch, sample_steps, hidden_dim
        q = q_vec.expand(-1, self.sample_steps, -1)
        _logits = self.yesno_predictor(torch.cat([h, q], dim=2)).view(
            -1, 3)  # batch * sample_steps, 3
        expanded_label = label.unsqueeze(1).expand(
            -1, self.sample_steps).reshape(-1)
        _loss = F.cross_entropy(_logits, expanded_label)
        corrects = (_logits.max(dim=-1)[1] == expanded_label).to(hidden.dtype)
        reward1 = (log_prob.reshape(-1) *
                   corrects).sum() / (self.sample_steps * batch)
        return _loss - reward1, _logits

    def reinforce_step_1(self, hidden, q_vec, label, prob, log_prob):
        batch, max_sen, hidden_dim = hidden.size()
        assert q_vec.size() == (batch, 1, hidden_dim)
        assert prob.size() == (batch, self.sample_steps, 1, max_sen)
        assert log_prob.size() == (batch, self.sample_steps, 1)
        expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1,
                                                     -1)
        h = prob.matmul(expanded_hidden).squeeze(
            2)  # batch, sample_steps, hidden_dim
        q = q_vec.expand(-1, self.sample_steps, -1)
        _logits = self.yesno_predictor(torch.cat([h, q], dim=2)).view(
            -1, 3)  # batch * sample_steps, 3
        expanded_label = label.unsqueeze(1).expand(
            -1, self.sample_steps).reshape(-1)  # batch * sample_steps

        _loss = F.cross_entropy(_logits, expanded_label, ignore_index=-1)

        _final_log_prob = F.log_softmax(_logits, dim=-1)
        ignore_mask = (expanded_label == -1)
        expanded_label = expanded_label.masked_fill(ignore_mask, 0)
        selected_log_prob = _final_log_prob.gather(
            1, expanded_label.unsqueeze(1)).squeeze(-1)
        assert selected_log_prob.size() == (
            batch * self.sample_steps, ), selected_log_prob.size()
        reward2 = -(log_prob.reshape(-1) *
                    (selected_log_prob *
                     (1 - ignore_mask.to(log_prob.dtype)))).sum() / (
                         self.sample_steps * batch)

        return _loss - reward2, _logits

    def simple_step(self, hidden, q_vec, label, prob):
        batch, max_sen, hidden_dim = hidden.size()
        assert q_vec.size() == (batch, 1, hidden_dim)
        assert prob.size() == (batch, 1, max_sen)
        h = prob.bmm(hidden)
        _logits = self.yesno_predictor(torch.cat([h, q_vec],
                                                 dim=2)).view(-1, 3)
        if label is not None:
            _loss = F.cross_entropy(_logits, label)
        else:
            _loss = _logits.new_zeros(1)
        return _loss, _logits
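# A minimal, self-contained sketch of the score-function (REINFORCE) estimator that
# sample_one_hot / reinforce_step above implement: sample a hard one-hot sentence choice,
# score the answer it yields, and subtract a log_prob * reward term from the loss.
# The function name and argument layout are illustrative only, not the repository's API.
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def reinforce_sentence_selection(sentence_sim, sentence_hidden, que_vec, labels,
                                 yesno_predictor, sample_steps=4):
    # sentence_sim:    [batch, max_sen] unnormalized sentence scores
    # sentence_hidden: [batch, max_sen, h] per-sentence evidence vectors
    # que_vec:         [batch, h] question vector; labels: [batch] answer classes
    probs = F.softmax(sentence_sim, dim=-1)
    dist = Categorical(probs)
    idx = dist.sample((sample_steps,))                                   # [steps, batch]
    log_prob = dist.log_prob(idx)                                        # [steps, batch]
    one_hot = F.one_hot(idx, probs.size(-1)).to(sentence_hidden.dtype)   # [steps, batch, max_sen]
    h = torch.einsum('sbm,bmh->sbh', one_hot, sentence_hidden)           # selected sentence per sample
    logits = yesno_predictor(torch.cat([h, que_vec.expand_as(h)], dim=-1))
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           labels.repeat(sample_steps))
    reward = (logits.argmax(dim=-1) == labels).to(log_prob.dtype)        # 1 where the sample led to a correct answer
    return loss - (log_prob * reward).mean()                             # as in reinforce_step: loss minus the reward term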
Exemple #21
0
class BertForMultiLabelClassification(BertPreTrainedModel):
    def __init__(self, config, num_labels=20):
        super(BertForMultiLabelClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.num_capsule = 10
        self.dim_capsule = 16
        self.caps = Caps_Layer(batch_size=12,
                               input_dim_capsule=config.hidden_size,
                               num_capsule=self.num_capsule,
                               dim_capsule=self.dim_capsule,
                               routings=5)
        self.dense = nn.Linear(self.num_capsule * self.dim_capsule, num_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        last_output, pooled_output = self.bert(input_ids,
                                               token_type_ids,
                                               attention_mask,
                                               output_all_encoded_layers=False)
        # last_output = torch.cuda.FloatTensor(last_output)
        # attention_mask = torch.cuda.FloatTensor(attention_mask)
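        # Masked mean pooling over the sequence: zero out padded positions, then divide
        # by the number of real tokens in each example.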
        pooled_output = torch.sum(
            last_output * attention_mask.float().unsqueeze(2),
            dim=1) / torch.sum(attention_mask.float(), dim=1, keepdim=True)
        '''
        batch_size = input_ids.size(0)
        caps_output = self.caps(last_output)  # (batch_size, num_capsule, dim_capsule)
        caps_output = caps_output.view(batch_size, -1)  # (batch_size, num_capsule*dim_capsule)
        caps_dropout = self.dropout(caps_output)
        logits = self.dense(caps_dropout)
        '''

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            # loss_fct = BCEWithLogitsLoss()
            # loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))

            # Focal loss. alpha and gamma are only used by the commented-out manual
            # implementation below; the FocalLoss module carries its own settings.
            alpha = 0.75
            gamma = 3

            x = logits.view(-1, self.num_labels)
            t = labels.view(-1, self.num_labels)
            '''
            p = x.sigmoid()
            pt = p*t + (1-p)*(1-t)
            w = alpha*t + (1-alpha)*(1-t)
            w = w*(1-pt).pow(gamma)
            # return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)
            return binary_cross_entropy(x, t, weight=w, smooth_eps=0.1, from_logits=True)
            '''
            loss_fct = FocalLoss(logits=True)
            loss = loss_fct(x, t)
            return loss
        else:
            return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
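# The FocalLoss module used above is not included in this excerpt. A minimal sketch that is
# consistent with the commented-out manual formula (alpha-balanced focal BCE on logits) might
# look like the following; treat it as an illustration rather than the original implementation.
import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=3, logits=True, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction

    def forward(self, x, t):
        # x: [N, num_labels] logits (or probabilities if logits=False); t: float multi-hot targets in {0, 1}
        p = torch.sigmoid(x) if self.logits else x
        pt = p * t + (1 - p) * (1 - t)                    # probability assigned to the true label
        w = self.alpha * t + (1 - self.alpha) * (1 - t)   # class-balancing weight
        w = w * (1 - pt).pow(self.gamma)                  # down-weight easy, well-classified examples
        loss = -w * torch.log(pt.clamp(min=1e-12))        # element-wise focal binary cross-entropy
        return loss.mean() if self.reduction == 'mean' else loss.sum()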
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of labels for the classifier. Default = 1.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
    Outputs:
        the multi-label classification logits of shape [batch_size, num_labels].
        Note that this variant's `forward` does not take `labels` and never computes a loss;
        the loss (e.g. BCEWithLogitsLoss) is expected to be computed outside the model.
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForMultiLabelSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels=1):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        _, pooled_output = self.bert(input_ids,
                                     token_type_ids,
                                     attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
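# A minimal training sketch for the class above. `config` is assumed to be a BertConfig such
# as the one built in the docstring example, and the tensors below are dummies; since
# forward() returns raw logits only, the multi-label loss (e.g. BCEWithLogitsLoss, as hinted
# in the commented-out lines earlier) is computed outside the model.
import torch

num_labels = 6
model = BertForMultiLabelSequenceClassification(config, num_labels=num_labels)
input_ids = torch.randint(0, 32000, (8, 128))           # [batch, seq_len] token ids
attention_mask = torch.ones_like(input_ids)
labels = torch.randint(0, 2, (8, num_labels)).float()   # multi-hot targets

logits = model(input_ids, token_type_ids=None, attention_mask=attention_mask)
loss = torch.nn.BCEWithLogitsLoss()(logits, labels)
loss.backward()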
class BertQAYesnoHierarchicalHardFP16(BertPreTrainedModel):
    """
    Hard:
    Hard attention over evidence sentences, trained with Gumbel-Softmax or reinforcement learning.
    """
    def __init__(self,
                 config,
                 evidence_lambda=0.8,
                 use_gumbel=True,
                 freeze_bert=False):
        super(BertQAYesnoHierarchicalHardFP16, self).__init__(config)
        logger.info(f'The model {self.__class__.__name__} is loading...')
        logger.info(f'The coefficient of evidence loss is {evidence_lambda}')
        logger.info(f'Use gumbel: {use_gumbel}')
        logger.info(f'Freeze BERT\'s parameters: {freeze_bert}')

        layers.set_seq_dropout(True)
        layers.set_my_dropout_prob(config.hidden_dropout_prob)
        rep_layers.set_seq_dropout(True)
        rep_layers.set_my_dropout_prob(config.hidden_dropout_prob)

        self.bert = BertModel(config)

        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        self.doc_sen_self_attn = rep_layers.LinearSelfAttention(
            config.hidden_size)
        self.que_self_attn = rep_layers.LinearSelfAttention(config.hidden_size)

        self.word_similarity = layers.AttentionScore(config.hidden_size,
                                                     250,
                                                     do_similarity=False)
        self.vector_similarity = layers.AttentionScore(config.hidden_size,
                                                       250,
                                                       do_similarity=False)

        self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3)
        self.evidence_lam = evidence_lambda
        self.use_gumbel = use_gumbel

        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                answer_choice=None,
                sentence_span_list=None,
                sentence_ids=None):
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list)

        batch, max_sen, doc_len = doc_sen_mask.size()

        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen, doc_len)

        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask,
                                                dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        doc_vecs = self.doc_sen_self_attn(doc,
                                          doc_mask).view(batch, max_sen, -1)

        # [batch, 1, h]
        # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1)
        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_hidden = self.hard_sample(
            sentence_sim,
            use_gumbel=self.use_gumbel,
            dim=-1,
            hard=True,
            mask=sentence_mask).bmm(word_hidden).squeeze(1)

        yesno_logits = self.yesno_predictor(
            torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = rep_layers.masked_softmax(sentence_sim,
                                                    sentence_mask,
                                                    dim=-1).squeeze_(1)
        output_dict = {
            'yesno_logits': torch.softmax(yesno_logits, dim=-1).detach().cpu().float(),
            'sentence_logits': sentence_scores
        }
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits,
                                          answer_choice,
                                          ignore_index=-1)
            loss += choice_loss
        # if sentence_ids is not None:
        #     log_sentence_sim = rep_layers.masked_log_softmax(sentence_sim.squeeze(1), sentence_mask, dim=-1)
        #     sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
        #     loss += sentence_loss
        output_dict['loss'] = loss
        return output_dict

    def hard_sample(self, logits, use_gumbel, dim=-1, hard=True, mask=None):
        if use_gumbel:
            if self.training:
                probs = rep_layers.gumbel_softmax(logits,
                                                  mask=mask,
                                                  hard=hard,
                                                  dim=dim)
                return probs
            else:
                probs = rep_layers.masked_softmax(logits, mask, dim=dim)
                index = probs.float().max(dim, keepdim=True)[1]
                y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
                return y_hard
        else:
            # Only Gumbel-Softmax based hard sampling is implemented here; returning
            # None silently (the original `pass`) would break the caller's .bmm().
            raise NotImplementedError('hard_sample requires use_gumbel=True.')
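# rep_layers.gumbel_softmax is not shown in this excerpt. A minimal straight-through
# Gumbel-Softmax sketch with the same masking convention used above (mask: 1 for masked
# positions, 0 for valid ones) might look like the following; PyTorch's own
# torch.nn.functional.gumbel_softmax covers the unmasked case.
import torch
import torch.nn.functional as F


def masked_gumbel_softmax(logits, mask=None, tau=1.0, hard=True, dim=-1):
    if mask is not None:
        logits = logits.masked_fill(mask.to(torch.bool), float('-inf'))  # drop masked positions
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbel) / tau, dim=dim)
    if not hard:
        return y_soft
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    # Straight-through estimator: the forward pass uses the one-hot sample,
    # the backward pass uses the gradient of the soft sample.
    return y_hard - y_soft.detach() + y_soft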
Exemple #24
0
class SANBertNetwork(nn.Module):
    def __init__(self, opt, bert_config=None,
                 use_parse=False, embedding_matrix=None, token2idx=None, stx_parse_dim=None, unked_words=None,
                 use_generic_features=False, num_generic_features=None, use_domain_features=False, num_domain_features=None, feature_dim=None):
        super(SANBertNetwork, self).__init__()
        self.dropout_list = []
        self.bert_config = BertConfig.from_dict(opt)
        self.bert = BertModel(self.bert_config)
        if opt['update_bert_opt'] > 0:
            for p in self.bert.parameters():
                p.requires_grad = False
        mem_size = self.bert_config.hidden_size
        self.scoring_list = nn.ModuleList()
        labels = [int(ls) for ls in opt['label_size'].split(',')]
        task_dropout_p = opt['tasks_dropout_p']
        self.bert_pooler = None

        self.use_parse = use_parse
        self.stx_parse_dim = stx_parse_dim
        self.use_generic_features = use_generic_features
        self.use_domain_features = use_domain_features

        clf_dim = self.bert_config.hidden_size
        if self.use_parse:
            self.treelstm = BinaryTreeLSTM(self.stx_parse_dim, embedding_matrix.clone(), token2idx, unked_words=unked_words)
            parse_clf_dim = self.stx_parse_dim * 2
            clf_dim += parse_clf_dim
            self.parse_clf = nn.Linear(parse_clf_dim, labels[0])
        if self.use_generic_features:
            self.generic_feature_proj = nn.Linear(num_generic_features, num_generic_features * feature_dim)
            generic_feature_clf_dim = num_generic_features * feature_dim
            clf_dim += generic_feature_clf_dim
            self.generic_feature_clf = nn.Linear(generic_feature_clf_dim, labels[0])
        if self.use_domain_features:
            self.domain_feature_proj = nn.Linear(num_domain_features, num_domain_features * feature_dim)
            domain_feature_clf_dim = num_domain_features * feature_dim
            clf_dim += domain_feature_clf_dim
            self.domain_feature_clf = nn.Linear(domain_feature_clf_dim, labels[0])

        assert len(labels) == 1
        for task, lab in enumerate(labels):
            dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
            self.dropout_list.append(dropout)
            out_proj = nn.Linear(self.bert_config.hidden_size, lab)
            self.scoring_list.append(out_proj)

        self.opt = opt
        self._my_init()
        self.set_embed(opt)
        if embedding_matrix is not None and self.use_parse:
            self.treelstm.embedding.weight = nn.Parameter(embedding_matrix)  # set again b/c self._my_init() overwrites it

    def _my_init(self):
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range * self.opt['init_ratio'])
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT pytorch version, which should be a bug.
                # Note that it only affects training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450)
                # support both old/latest version
                if 'beta' in dir(module) and 'gamma' in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()
        self.apply(init_weights)

    def nbert_layer(self):
        return len(self.bert.encoder.layer)

    def freeze_layers(self, max_n):
        assert max_n < self.nbert_layer()
        for i in range(0, max_n):
            self.freeze_layer(i)

    def freeze_layer(self, n):
        assert n < self.nbert_layer()
        layer = self.bert.encoder.layer[n]
        for p in layer.parameters():
            p.requires_grad = False

    def set_embed(self, opt):
        bert_embeddings = self.bert.embeddings
        emb_opt = opt['embedding_opt']
        if emb_opt == 1:
            for p in bert_embeddings.word_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 2:
            for p in bert_embeddings.position_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 3:
            for p in bert_embeddings.token_type_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 4:
            for p in bert_embeddings.token_type_embeddings.parameters():
                p.requires_grad = False
            for p in bert_embeddings.position_embeddings.parameters():
                p.requires_grad = False

    def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0,
                bin_parse_as=None, bin_parse_bs=None, parse_as_mask=None, parse_bs_mask=None,
                generic_features=None, domain_features=None):
        all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        sequence_output = all_encoder_layers[-1]
        if self.bert_pooler is not None:
            pooled_output = self.bert_pooler(sequence_output)
        pooled_output = self.dropout_list[task_id](pooled_output)
        logits = self.scoring_list[task_id](pooled_output)

        if self.use_parse:
            parse_embeddings = torch.FloatTensor(len(input_ids), self.stx_parse_dim * 2).to(input_ids.device)
            assert len(bin_parse_as) == len(bin_parse_bs) == len(parse_as_mask) == len(parse_bs_mask)
            for i, (parse_a, parse_b, parse_a_mask, parse_b_mask) in enumerate(zip(bin_parse_as, bin_parse_bs, parse_as_mask, parse_bs_mask)):
                parse_a = parse_a[:parse_a_mask.sum()]
                parse_b = parse_b[:parse_b_mask.sum()]
                t = Tree.from_char_indices(parse_a)
                parse_embeddings[i,:self.stx_parse_dim] = self.treelstm(t)[1]
                t = Tree.from_char_indices(parse_b)
                parse_embeddings[i,self.stx_parse_dim:] = self.treelstm(t)[1]
            logits += self.parse_clf(self.dropout_list[task_id](parse_embeddings))

        if self.use_generic_features:
            # features: bsz * n_features
            generic_feature_embeddings = F.relu(self.generic_feature_proj(generic_features))
            logits += self.generic_feature_clf(self.dropout_list[task_id](generic_feature_embeddings))

        if self.use_domain_features:
            # features: bsz * n_features
            domain_feature_embeddings = F.relu(self.domain_feature_proj(domain_features))
            logits += self.domain_feature_clf(self.dropout_list[task_id](domain_feature_embeddings))

        return logits
Exemple #25
0
class SANBertNetwork(nn.Module):
    def __init__(self, opt, bert_config=None):
        super(SANBertNetwork, self).__init__()
        self.dropout_list = nn.ModuleList()
        self.encoder_type = opt['encoder_type']
        if opt['encoder_type'] == EncoderModelType.ROBERTA:
            from fairseq.models.roberta import RobertaModel
            self.bert = RobertaModel.from_pretrained(opt['init_checkpoint'])
            hidden_size = self.bert.args.encoder_embed_dim
            self.pooler = LinearPooler(hidden_size)
        else: 
            self.bert_config = BertConfig.from_dict(opt)
            self.bert = BertModel(self.bert_config)
            hidden_size = self.bert_config.hidden_size

        if opt.get('dump_feature', False):
            self.opt = opt
            return
        if opt['update_bert_opt'] > 0:
            for p in self.bert.parameters():
                p.requires_grad = False
        self.decoder_opt = opt['answer_opt']
        self.task_types = opt["task_types"]
        self.scoring_list = nn.ModuleList()
        labels = [int(ls) for ls in opt['label_size'].split(',')]
        task_dropout_p = opt['tasks_dropout_p']

        for task, lab in enumerate(labels):
            decoder_opt = self.decoder_opt[task]
            task_type = self.task_types[task]
            dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
            self.dropout_list.append(dropout)
            if task_type == TaskType.Span:
                assert decoder_opt != 1
                out_proj = nn.Linear(hidden_size, 2)
            elif task_type == TaskType.SeqenceLabeling:
                out_proj = nn.Linear(hidden_size, lab)
            else:
                if decoder_opt == 1:
                    out_proj = SANClassifier(hidden_size, hidden_size, lab, opt, prefix='answer', dropout=dropout)
                else:
                    out_proj = nn.Linear(hidden_size, lab)
            self.scoring_list.append(out_proj)

        self.opt = opt
        self._my_init()

    def _my_init(self):
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=0.02 * self.opt['init_ratio'])
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT pytorch version, which should be a bug.
                # Note that it only affects training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450)
                # support both old/latest version
                if 'beta' in dir(module) and 'gamma' in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0):
        if attention_mask is not None and attention_mask.dtype == torch.uint8:
            attention_mask = attention_mask.bool()

        if premise_mask is not None and premise_mask.dtype == torch.uint8:
            premise_mask = premise_mask.bool()

        if hyp_mask is not None and hyp_mask.dtype == torch.uint8:
            hyp_mask = hyp_mask.bool()

        if self.encoder_type == EncoderModelType.ROBERTA:
            sequence_output = self.bert.extract_features(input_ids)
            pooled_output = self.pooler(sequence_output)
        else:
            all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
            sequence_output = all_encoder_layers[-1]

        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        if task_type == TaskType.Span:
            assert decoder_opt != 1
            sequence_output = self.dropout_list[task_id](sequence_output)
            logits = self.scoring_list[task_id](sequence_output)
            start_scores, end_scores = logits.split(1, dim=-1)
            start_scores = start_scores.squeeze(-1)
            end_scores = end_scores.squeeze(-1)
            return start_scores, end_scores
        elif task_type == TaskType.SeqenceLabeling:
            pooled_output = sequence_output  # token-level output; also valid on the RoBERTa path
            pooled_output = self.dropout_list[task_id](pooled_output)
            pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
            logits = self.scoring_list[task_id](pooled_output)
            return logits
        else:
            if decoder_opt == 1:
                max_query = hyp_mask.size(1)
                assert max_query > 0
                assert premise_mask is not None
                assert hyp_mask is not None
                hyp_mem = sequence_output[:, :max_query, :]
                logits = self.scoring_list[task_id](sequence_output, hyp_mem, premise_mask, hyp_mask)
            else:
                pooled_output = self.dropout_list[task_id](pooled_output)
                logits = self.scoring_list[task_id](pooled_output)
            return logits
class BertQAYesnoHierarchicalSingle(BertPreTrainedModel):
    """
    BertForQuestionAnsweringForYesNo

    Model Hierarchical Attention:
        - Use a hierarchical attention module to predict Non/Yes/No.
        - Add supervision to the sentence-level attention.

    Sentence-level model.
    """
    def __init__(self,
                 config,
                 evidence_lambda=0.8,
                 negative_lambda=1.0,
                 add_entropy: bool = False,
                 fix_bert: bool = False):
        super(BertQAYesnoHierarchicalSingle, self).__init__(config)
        logger.info(f'The model {self.__class__.__name__} is loading...')
        logger.info(f'The coefficient of evidence loss is {evidence_lambda}')
        logger.info(
            f'The coefficient of negative samples loss is {negative_lambda}')
        logger.info(f'Fix parameters of BERT: {fix_bert}')
        logger.info(f'Add entropy loss: {add_entropy}')
        # logger.info(f'Use bidirectional attention before summarizing vectors: {bi_attention}')

        layers.set_seq_dropout(True)
        layers.set_my_dropout_prob(config.hidden_dropout_prob)

        self.bert = BertModel(config)
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # self.answer_choice = nn.Linear(config.hidden_size, 2)
        if fix_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.doc_sen_self_attn = layers.LinearSelfAttnAllennlp(
            config.hidden_size)
        self.que_self_attn = layers.LinearSelfAttn(config.hidden_size)

        self.word_similarity = layers.AttentionScore(config.hidden_size,
                                                     250,
                                                     do_similarity=False)
        self.vector_similarity = layers.AttentionScore(config.hidden_size,
                                                       250,
                                                       do_similarity=False)

        # self.doc_sen_encoder = layers.StackedBRNN(config.hidden_size, 125, num_layers=1)

        # self.yesno_predictor = nn.Linear(config.hidden_size, 2)
        self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3)
        self.evidence_lam = evidence_lambda
        self.negative_lam = negative_lambda
        self.add_entropy = add_entropy

        self.apply(self.init_bert_weights)

    def forward(self,
                ques_input_ids,
                ques_input_mask,
                pass_input_ids,
                pass_input_mask,
                answer_choice=None,
                sentence_ids=None,
                sentence_label=None):

        # Encoding question
        q_len = ques_input_ids.size(1)
        question, _ = self.bert(ques_input_ids,
                                token_type_ids=None,
                                attention_mask=ques_input_mask,
                                output_all_encoded_layers=False)
        # Encoding passage
        batch, max_sen_num, p_len = pass_input_ids.size()
        pass_input_ids = pass_input_ids.reshape(batch * max_sen_num, p_len)
        pass_input_mask = pass_input_mask.reshape(batch * max_sen_num, p_len)
        passage, _ = self.bert(pass_input_ids,
                               token_type_ids=None,
                               attention_mask=pass_input_mask,
                               output_all_encoded_layers=False)

        que_mask = (1 - ques_input_mask).byte()
        que_vec = layers.weighted_avg(question,
                                      self.que_self_attn(question,
                                                         que_mask)).view(
                                                             batch, 1, -1)

        doc = passage.reshape(batch, max_sen_num * p_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec,
                                        doc).view(batch * max_sen_num, p_len)

        # doc_mask = 1 - pass_input_mask
        doc_mask = pass_input_mask  # 1 for true value and 0 for mask
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, doc_mask,
                                     dim=1).unsqueeze(1).bmm(passage)
        word_hidden = word_hidden.view(batch, max_sen_num, -1)

        sentence_mask = pass_input_mask.reshape(
            batch, max_sen_num, p_len).sum(dim=-1).ge(1.0).float()

        # 1 - doc_mask: 0 for true value and 1 for mask
        doc_vecs = layers.weighted_avg(
            passage,
            self.doc_sen_self_attn(passage,
                                   1 - doc_mask)).view(batch, max_sen_num, -1)

        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        # sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask)
        sentence_scores = masked_softmax(
            sentence_sim, sentence_mask)  # 1 for true value and 0 for mask
        sentence_hidden = sentence_scores.bmm(word_hidden).squeeze(1)

        yesno_logits = self.yesno_predictor(
            torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = sentence_scores.squeeze(1)
        max_sentence_score = sentence_scores.max(dim=-1)
        output_dict = {
            'yesno_logits': yesno_logits,
            'sentence_logits': sentence_scores,
            'max_weight': max_sentence_score[0],
            'max_weight_index': max_sentence_score[1]
        }
        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits,
                                          answer_choice,
                                          ignore_index=-1)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1),
                                                  sentence_mask,
                                                  dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(
                log_sentence_sim, sentence_ids, ignore_index=-1)
            loss += sentence_loss
            if self.add_entropy:
                no_evidence_mask = (sentence_ids != -1)
                entropy = layers.get_masked_entropy(sentence_scores,
                                                    mask=no_evidence_mask)
                loss += self.evidence_lam * entropy
        if sentence_label is not None:
            # sentence_label: batch * List[k]
            # [batch, max_sen]
            # log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1)
            sentence_prob = 1 - sentence_scores
            log_sentence_sim = -torch.log(sentence_prob + 1e-15)
            negative_loss = 0
            for b in range(batch):
                for sen_id, k in enumerate(sentence_label[b]):
                    negative_loss += k * log_sentence_sim[b][sen_id]
            negative_loss /= batch
            loss += self.negative_lam * negative_loss
        output_dict['loss'] = loss
        return output_dict
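# The masked_softmax / masked_log_softmax helpers used throughout the snippets above are
# imported from the surrounding project and are not shown here. A minimal sketch under the
# convention used in this class (mask: 1 for true values, 0 for padding) could be:
import torch
import torch.nn.functional as F


def masked_softmax(vector, mask, dim=-1):
    if mask is None:
        return F.softmax(vector, dim=dim)
    while mask.dim() < vector.dim():
        mask = mask.unsqueeze(1)  # e.g. [batch, max_sen] -> [batch, 1, max_sen]
    # Fully masked rows would produce NaN here; real implementations add a small epsilon.
    vector = vector.masked_fill(mask == 0, float('-inf'))
    return F.softmax(vector, dim=dim)


def masked_log_softmax(vector, mask, dim=-1):
    if mask is None:
        return F.log_softmax(vector, dim=dim)
    while mask.dim() < vector.dim():
        mask = mask.unsqueeze(1)
    vector = vector.masked_fill(mask == 0, float('-inf'))
    return F.log_softmax(vector, dim=dim)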