class NER(BaseModel):
    """Sequence tagger: a Bert encoder, a linear tag projection and a CRF."""

    def __init__(self, n_tag, tag2id, vocab_size, type_vocab_size,
                 max_position_embeddings, n_layers, h_dim, n_heads, pf_dim,
                 layer_norm_eps, dropout, device, hid_act, pooler_act,
                 add_pooler, checkpoint=None, **kwargs):
        """Create the encoder, the tag projection and the CRF layer.

        Args:
            n_tag: Output size of the linear classifier; also given to CRF.
            tag2id: Forwarded unchanged to the CRF constructor.
            vocab_size, type_vocab_size, max_position_embeddings, n_layers,
            h_dim, n_heads, pf_dim, layer_norm_eps, dropout, device, hid_act,
            pooler_act, add_pooler: Forwarded positionally to Bert.
                ``h_dim`` is also the classifier input size and ``dropout``
                the probability of the dropout layer.
            checkpoint: Optional mapping of encoder weights. When given, the
                weights are loaded into the encoder and the encoder is frozen.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.bert = Bert(vocab_size, type_vocab_size, max_position_embeddings,
                         n_layers, h_dim, n_heads, pf_dim, layer_norm_eps,
                         dropout, device, hid_act, pooler_act, add_pooler)

        if checkpoint is not None:
            # Keys are paired by position, not by name: the i-th checkpoint
            # entry is stored under the encoder's i-th state-dict key.
            remapped = OrderedDict(
                (own_key, checkpoint[src_key])
                for own_key, src_key in zip(self.bert.state_dict().keys(),
                                            checkpoint.keys()))
            self.bert.load_state_dict(remapped)

            # A loaded encoder is kept fixed; only the layers below train.
            for weight in self.bert.parameters():
                weight.requires_grad = False

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(h_dim, n_tag)
        self.crf = CRF(n_tag, tag2id, device)

    def forward(self, inputs, inputs_mask, token_type_ids, tags=None,
                position_ids=None):
        """Run the tagger and optionally score it against gold tags.

        Args:
            inputs, inputs_mask, token_type_ids, position_ids: Passed to the
                Bert encoder in that order.
            tags: Optional gold tags. When given, the CRF is evaluated on the
                logits, the tags and ``inputs_mask``.

        Returns:
            ``[logits]`` when ``tags`` is None, otherwise
            ``[-1 * crf_output, logits]``.
        """
        hidden_states, _ = self.bert(inputs, inputs_mask, token_type_ids,
                                     position_ids)
        logits = self.classifier(self.dropout(hidden_states))

        if tags is None:
            return [logits]

        # The CRF output is negated before it is returned.
        # NOTE(review): presumably the CRF yields a log-likelihood - confirm.
        crf_output = self.crf(logits, tags, inputs_mask)
        return [-1 * crf_output, logits]
# Example #2
# 0
class ToEmbedding(BaseModel):
    """Bert encoder followed by a dense projection and a vocabulary classifier."""

    def __init__(self, embedding_dim, vocab_size, type_vocab_size,
                 max_position_embeddings, n_layers, h_dim, n_heads, pf_dim,
                 layer_norm_eps, dropout, device, hid_act, pooler_act,
                 add_pooler, checkpoint=None, **kwargs):
        """Create the encoder, the projection layer and the classifier.

        Args:
            embedding_dim: Output size of the dense projection and input size
                of the classifier.
            vocab_size: Forwarded to Bert; also the classifier output size.
            type_vocab_size, max_position_embeddings, n_layers, h_dim,
            n_heads, pf_dim, layer_norm_eps, dropout, device, hid_act,
            pooler_act, add_pooler: Forwarded positionally to Bert.
                ``h_dim`` is also the projection input size and ``dropout``
                the probability of the dropout layer.
            checkpoint: Optional mapping of encoder weights. When given, the
                weights are loaded into the encoder and the encoder is frozen.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.bert = Bert(vocab_size, type_vocab_size, max_position_embeddings,
                         n_layers, h_dim, n_heads, pf_dim, layer_norm_eps,
                         dropout, device, hid_act, pooler_act, add_pooler)

        if checkpoint is not None:
            # Keys are paired by position, not by name: the i-th checkpoint
            # entry is stored under the encoder's i-th state-dict key.
            remapped = OrderedDict(
                (own_key, checkpoint[src_key])
                for own_key, src_key in zip(self.bert.state_dict().keys(),
                                            checkpoint.keys()))
            self.bert.load_state_dict(remapped)

            # A loaded encoder is kept fixed; only the layers below train.
            for weight in self.bert.parameters():
                weight.requires_grad = False

        self.dense_layer = nn.Linear(h_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs, inputs_mask, token_type_ids, position_ids=None):
        """Encode the inputs, project them and score them over the vocabulary.

        Args:
            inputs, inputs_mask, token_type_ids, position_ids: Passed to the
                Bert encoder in that order.

        Returns:
            The classifier output computed from the projected encoder output.
        """
        hidden_states, _ = self.bert(inputs, inputs_mask, token_type_ids,
                                     position_ids)
        # Dropout is applied after the projection, not on the encoder output.
        projected = self.dropout(self.dense_layer(hidden_states))
        return self.classifier(projected)
# Example #3
# 0
class MultiInstanceLearning(BaseModel):
    """Bag-level classifier: Bert -> dense projection -> PCNN -> attention.

    Sentence representations produced by the PCNN are combined per bag with
    attention weights derived from the rows of the final linear layer
    (``self.fc.weight``), and the combined representation is classified by
    that same layer.
    """

    def __init__(self,
                 batch_size,
                 n_classes,
                 max_seq_length,
                 vocab_size,
                 type_vocab_size,
                 max_position_embeddings,
                 embedding_dim,
                 pos_dim,
                 kernel_size,
                 padding_size,
                 pcnn_h_dim,
                 n_layers,
                 h_dim,
                 n_heads,
                 pf_dim,
                 layer_norm_eps,
                 dropout,
                 device,
                 hid_act,
                 pooler_act,
                 add_pooler,
                 bert_checkpoint=None,
                 dense_layer_checkpoint=None,
                 **kwargs):
        """Build the encoder, projection, PCNN and classification layers.

        Args:
            batch_size: Stored on ``self.batch_size``; not read elsewhere in
                this class.
            n_classes: Output size of ``self.fc``; stored on ``self.n_classes``.
            max_seq_length, pcnn_h_dim, pos_dim, kernel_size, padding_size:
                Forwarded to the PCNN constructor.
            vocab_size: Forwarded to both Bert and PCNN.
            embedding_dim: Output size of ``self.dense_layer``; also forwarded
                to PCNN.
            type_vocab_size, max_position_embeddings, n_layers, h_dim, n_heads,
            pf_dim, layer_norm_eps, pooler_act, add_pooler: Forwarded to Bert
                (``h_dim`` is also the input size of ``self.dense_layer``).
            dropout, hid_act, device: Forwarded to both Bert and PCNN;
                ``dropout`` is also the probability of ``self.dropout`` and
                ``device`` is stored on ``self.device``.
            bert_checkpoint: Optional mapping of encoder weights. When given,
                the weights are loaded into Bert and Bert is frozen.
            dense_layer_checkpoint: Optional mapping; its
                ``['state_dict'][0]`` entry is loaded into ``self.dense_layer``,
                which is then frozen.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.bert = Bert(vocab_size, type_vocab_size, max_position_embeddings,
                         n_layers, h_dim, n_heads, pf_dim, layer_norm_eps,
                         dropout, device, hid_act, pooler_act, add_pooler)

        self.dense_layer = nn.Linear(h_dim, embedding_dim)
        if bert_checkpoint is not None:
            # Keys are paired by position, not by name: the i-th checkpoint
            # entry is stored under the encoder's i-th state-dict key.
            new_state_dict = OrderedDict()
            for k1, k2 in zip(self.bert.state_dict().keys(),
                              bert_checkpoint.keys()):
                new_state_dict[k1] = bert_checkpoint[k2]

            self.bert.load_state_dict(new_state_dict)
            # A loaded encoder is frozen.
            for param in self.bert.parameters():
                param.requires_grad = False

        if dense_layer_checkpoint is not None:
            # The projection weights are taken from the first element of the
            # checkpoint's 'state_dict' entry, and the layer is then frozen.
            self.dense_layer.load_state_dict(
                dense_layer_checkpoint['state_dict'][0])
            for param in self.dense_layer.parameters():
                param.requires_grad = False

        self.pcnn = PCNN(max_seq_length, pcnn_h_dim, vocab_size, embedding_dim,
                         pos_dim, kernel_size, padding_size, dropout, hid_act,
                         device)

        self.n_classes = n_classes
        self.batch_size = batch_size
        self.device = device
        # The classifier input size is read from the PCNN instance.
        self.fc = nn.Linear(self.pcnn.h_dim, n_classes)
        self.softmax = nn.Softmax(-1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, scope, inputs, bag_size=0, label=None):
        """Compute bag-level outputs with attention over the bag's sentences.

        Args:
            scope: Per-bag ``[start, end)`` row ranges, indexed as
                ``scope[i][0]`` / ``scope[i][1]``. Only read when
                ``bag_size == 0``; that path calls ``torch.zeros_like(scope)``,
                so it must be a tensor there.
            inputs: Sequence of six items unpacked as ``tokens_id,
                tokens_mask, segment_id, pos1, pos2, pcnn_mask``.
            bag_size: When greater than 0, every bag is taken to hold exactly
                this many sentences; when 0, bags are delimited by ``scope``.
            label: Optional per-bag class indices. When given, the attention
                query for each sentence is its bag's label row of
                ``self.fc.weight``; when None, attention is computed once per
                class.

        Returns:
            Tuple ``(bag_logits, att_score)``. With ``label`` given,
            ``bag_logits`` is the raw output of ``self.fc``; with ``label``
            None it has already been passed through softmax.
        """
        tokens_id, tokens_mask, segment_id, pos1, pos2, pcnn_mask = inputs
        encoder_output, _ = self.bert(tokens_id, tokens_mask, segment_id)
        encoder_output = self.dropout(encoder_output)
        output = self.dense_layer(encoder_output)

        # NOTE(review): `tokens_id` is reshaped/sliced below but never read
        # again, and `output` is not sliced to [begin:end] like pos1/pos2/
        # pcnn_mask are - confirm this is intended.
        if bag_size > 0:
            # Fixed-size bags: collapse all leading dimensions into one.
            tokens_id = tokens_id.view(-1, tokens_id.size(-1))
            pos1 = pos1.view(-1, pos1.size(-1))
            pos2 = pos2.view(-1, pos2.size(-1))
            pcnn_mask = pcnn_mask.view(-1, pcnn_mask.size(-1))
        else:
            # Variable-size bags: keep only the rows covered by `scope`, then
            # shift `scope` so that it starts at 0.
            begin, end = scope[0][0], scope[-1][1]
            tokens_id = tokens_id[begin:end, :].view(-1, tokens_id.size(-1))
            pos1 = pos1[begin:end, :].view(-1, pos1.size(-1))
            pos2 = pos2[begin:end, :].view(-1, pos2.size(-1))
            pcnn_mask = pcnn_mask[begin:end, :].view(-1, pcnn_mask.size(-1))
            scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin))

        rep = self.pcnn(output, pos1, pos2, pcnn_mask)  # (nsum, H)

        if label is not None:
            # Label-guided attention: the query is the classifier row of each
            # bag's label.
            if bag_size == 0:
                bag_rep = []
                # Expand the per-bag labels to one label per sentence.
                query = torch.zeros((rep.size(0))).long().to(self.device)
                for i in range(len(scope)):
                    query[scope[i][0]:scope[i][1]] = label[i]
                att_mat = self.fc.weight[query]  # (nsum, H)
                att_score = (rep * att_mat).sum(-1)  # (nsum)

                for i in range(len(scope)):
                    bag_mat = rep[scope[i][0]:scope[i][1]]  # (n, H)
                    # Normalise the scores within this bag only.
                    softmax_att_score = self.softmax(
                        att_score[scope[i][0]:scope[i][1]])  # (n)
                    bag_rep.append(
                        (softmax_att_score.unsqueeze(-1) *
                         bag_mat).sum(0))  # (n, 1) * (n, H) -> (n, H) -> (H)
                bag_rep = torch.stack(bag_rep, 0)  # (B, H)
            else:
                batch_size = label.size(0)
                query = label.unsqueeze(1)  # (B, 1)
                att_mat = self.fc.weight[query]  # (B, 1, H)
                rep = rep.view(batch_size, bag_size, -1)
                att_score = (rep * att_mat).sum(-1)  # (B, bag)
                softmax_att_score = self.softmax(att_score)  # (B, bag)
                bag_rep = (softmax_att_score.unsqueeze(-1) * rep).sum(
                    1)  # (B, bag, 1) * (B, bag, H) -> (B, bag, H) -> (B, H)
            bag_rep = self.dropout(bag_rep)
            bag_logits = self.fc(bag_rep)  # (B, N)
        else:
            # No labels: build one attention-weighted bag representation per
            # class, classify each, and keep class k's score from the
            # representation built with class k's query (the diagonal).
            if bag_size == 0:
                bag_logits = []
                att_score = torch.matmul(rep, self.fc.weight.transpose(
                    0, 1))  # (nsum, H) * (H, N) -> (nsum, N)
                for i in range(len(scope)):
                    bag_mat = rep[scope[i][0]:scope[i][1]]  # (n, H)
                    softmax_att_score = self.softmax(
                        att_score[scope[i][0]:scope[i][1]].transpose(
                            0, 1))  # (N, (softmax)n)
                    rep_for_each_rel = torch.matmul(
                        softmax_att_score,
                        bag_mat)  # (N, n) * (n, H) -> (N, H)
                    logit_for_each_rel = self.softmax(
                        self.fc(rep_for_each_rel))  # ((each rel)N, (logit)N)
                    logit_for_each_rel = logit_for_each_rel.diag()  # (N)
                    bag_logits.append(logit_for_each_rel)
                bag_logits = torch.stack(bag_logits, 0)  # after **softmax**
            else:
                # The batch size is derived from the flattened sentence count.
                batch_size = rep.size(0) // bag_size
                att_score = torch.matmul(rep, self.fc.weight.transpose(
                    0, 1))  # (nsum, H) * (H, N) -> (nsum, N)
                att_score = att_score.view(batch_size, bag_size,
                                           -1)  # (B, bag, N)
                rep = rep.view(batch_size, bag_size, -1)  # (B, bag, H)
                softmax_att_score = self.softmax(att_score.transpose(
                    1, 2))  # (B, N, (softmax)bag)
                rep_for_each_rel = torch.matmul(
                    softmax_att_score,
                    rep)  # (B, N, bag) * (B, bag, H) -> (B, N, H)
                bag_logits = self.softmax(self.fc(rep_for_each_rel)).diagonal(
                    dim1=1, dim2=2)  # (B, (each rel)N)

        return bag_logits, att_score