def __init__(self, config, classes_num):
        """Build the extraction net: frozen BERT encoder, BiLSTM, and pointer-network heads.

        Args:
            config: model configuration object (provides hidden_size, layer_norm_eps, ...).
            classes_num: number of predicate classes for the pointer network.
        """
        super().__init__(config, classes_num)

        print('spo_transformers_freeze')
        self.classes_num = classes_num

        # Backbone BERT encoder. Every BERT parameter is frozen, so only the
        # layers constructed below receive gradient updates.
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = False

        # Sentence-level encoder over the frozen BERT outputs
        # (hidden_size // 2 per direction, presumably a BiLSTM — see SentenceEncoder).
        self.lstm_encoder = SentenceEncoder(config.hidden_size, config.hidden_size // 2)

        # Layer norm conditioned on an external vector (see ConditionalLayerNorm).
        self.LayerNorm = ConditionalLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Pointer-network heads: start/end logits for subjects, and
        # start/end logits per predicate class for objects.
        self.subject_dense = nn.Linear(config.hidden_size, 2)
        self.po_dense = nn.Linear(config.hidden_size, self.classes_num * 2)

        # Element-wise BCE with logits; the caller is responsible for reduction/masking.
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        self.apply(self.init_bert_weights)
# Esempio n. 2
# 0
    def __init__(self, args, word_emb, ent_conf, spo_conf):
        """ERENet variant mixing char and (frozen) word embeddings, with CRF tagging heads.

        Args:
            args: hyper-parameter namespace (max_len, char_vocab_size, emb sizes, hidden_size).
            word_emb: pretrained word-embedding matrix, loaded frozen.
            ent_conf: entity label set; its size fixes the entity CRF tag space.
            spo_conf: predicate (SPO) label set; its size fixes the relation CRF tag space.
        """
        print('mhs using only char2v+w2v mixed  and word_emb is freeze ')
        super().__init__()

        self.max_len = args.max_len

        # Pretrained word embeddings, kept frozen during training.
        self.word_emb = nn.Embedding.from_pretrained(
            torch.tensor(word_emb, dtype=torch.float32),
            freeze=True,
            padding_idx=0,
        )
        # Trainable character embeddings.
        self.char_emb = nn.Embedding(
            num_embeddings=args.char_vocab_size,
            embedding_dim=args.char_emb_size,
            padding_idx=0,
        )
        # Projects word vectors into the char-embedding space so the two can be mixed.
        self.word_convert_char = nn.Linear(args.word_emb_size, args.char_emb_size, bias=False)

        self.classes_num = len(spo_conf)

        self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)

        # Single-layer Transformer over the encoder output (width = 2 * hidden_size).
        self.encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, nhead=3)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
        self.LayerNorm = ConditionalLayerNorm(args.hidden_size * 2, eps=1e-12)

        # CRF heads: one emission/CRF pair for entities, one for predicate labels.
        self.ent_emission = nn.Linear(args.hidden_size * 2, len(ent_conf))
        self.ent_crf = CRF(len(ent_conf), batch_first=True)
        self.emission = nn.Linear(args.hidden_size * 2, len(spo_conf))
        self.crf = CRF(len(spo_conf), batch_first=True)

        # Element-wise BCE with logits; the caller is responsible for reduction/masking.
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')
# Esempio n. 3
# 0
    def __init__(self, args, word_emb, spo_conf):
        """ERENet variant for drug extraction using character embeddings only.

        Args:
            args: hyper-parameter namespace (char_vocab_size, char_emb_size, hidden_size).
            word_emb: accepted for interface compatibility but not used by this variant.
            spo_conf: predicate label set; its size fixes the pointer-network width.
        """
        print('drug extract using only char2v')
        super().__init__()

        # Characters are the only input representation in this variant.
        self.char_emb = nn.Embedding(
            num_embeddings=args.char_vocab_size,
            embedding_dim=args.char_emb_size,
            padding_idx=0,
        )

        self.classes_num = len(spo_conf)

        # Two stacked sentence encoders; the second consumes the doubled hidden size.
        self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
        self.second_sentence_encoder = SentenceEncoder(args, args.hidden_size * 2)

        # Two-entry embedding indexed by a 0/1 token flag
        # (presumably marks subject tokens — name suggests an entity indicator).
        self.token_entity_emb = nn.Embedding(
            num_embeddings=2,
            embedding_dim=args.hidden_size * 2,
            padding_idx=0,
        )

        # Single-layer Transformer over the encoder output (width = 2 * hidden_size).
        self.encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, nhead=3)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
        self.LayerNorm = ConditionalLayerNorm(args.hidden_size * 2, eps=1e-12)

        # Pointer-network heads: start/end logits for subjects, and
        # start/end logits per predicate class for objects.
        self.subject_dense = nn.Linear(args.hidden_size * 2, 2)
        self.po_dense = nn.Linear(args.hidden_size * 2, self.classes_num * 2)

        # Element-wise BCE with logits; the caller is responsible for reduction/masking.
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')
# Esempio n. 4
# 0
    def __init__(self, config, classes_num):
        """BERT-based extraction net with a token-flag embedding and pointer-network heads.

        Args:
            config: model configuration (provides hidden_size, layer_norm_eps, ...).
            classes_num: number of predicate classes for the pointer network.
        """
        super().__init__(config, classes_num)

        self.classes_num = classes_num
        self.bert = BertModel(config)

        # Two-entry embedding indexed by a 0/1 token flag
        # (presumably marks subject tokens — name suggests an entity indicator).
        self.token_entity_emb = nn.Embedding(
            num_embeddings=2,
            embedding_dim=config.hidden_size,
            padding_idx=0,
        )

        # Layer norm conditioned on an external vector (see ConditionalLayerNorm).
        self.LayerNorm = ConditionalLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Pointer-network heads: start/end logits for subjects, and
        # start/end logits per predicate class for objects.
        self.subject_dense = nn.Linear(config.hidden_size, 2)
        self.po_dense = nn.Linear(config.hidden_size, self.classes_num * 2)

        # Element-wise BCE with logits; the caller is responsible for reduction/masking.
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        self.apply(self.init_bert_weights)
# Esempio n. 5
# 0
    def __init__(self, config, classes_num):
        """BERT-based extraction net (Hugging Face-style init) with pointer-network heads.

        Args:
            config: model configuration (provides hidden_size, layer_norm_eps, ...).
            classes_num: number of predicate classes for the pointer network.
        """
        super().__init__(config, classes_num)

        self.classes_num = classes_num

        # Backbone BERT encoder.
        self.bert = BertModel(config)

        # Two-entry embedding indexed by a 0/1 token flag
        # (presumably marks subject tokens — name suggests an entity indicator).
        self.token_entity_emb = nn.Embedding(
            num_embeddings=2,
            embedding_dim=config.hidden_size,
            padding_idx=0,
        )

        # Layer norm conditioned on an external vector (see ConditionalLayerNorm).
        self.LayerNorm = ConditionalLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Pointer-network heads: start/end logits for subjects, and
        # start/end logits per predicate class for objects.
        self.subject_dense = nn.Linear(config.hidden_size, 2)
        self.po_dense = nn.Linear(config.hidden_size, self.classes_num * 2)

        # Element-wise BCE with logits; the caller is responsible for reduction/masking.
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        self.init_weights()
# Esempio n. 6
# 0
    def __init__(self, config, classes_num):
        """Entity-relation extraction net over BERT with pointer-network output heads.

        Args:
            config: model configuration (provides hidden_size, layer_norm_eps, ...).
            classes_num: number of predicate classes for the pointer network.
        """
        super().__init__(config, classes_num)

        self.classes_num = classes_num
        self.bert = BertModel(config)

        # Two-entry embedding indexed by a 0/1 token flag
        # (presumably marks subject tokens — name suggests an entity indicator).
        self.token_entity_emb = nn.Embedding(
            num_embeddings=2,
            embedding_dim=config.hidden_size,
            padding_idx=0,
        )

        # Layer norm conditioned on an external vector (see ConditionalLayerNorm).
        self.LayerNorm = ConditionalLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Pointer-network heads: start/end logits for subjects, and
        # start/end logits per predicate class for objects.
        self.subject_dense = nn.Linear(config.hidden_size, 2)
        self.po_dense = nn.Linear(config.hidden_size, self.classes_num * 2)

        # Binary cross-entropy loss with logits: fuses the sigmoid with BCELoss;
        # element-wise (no reduction), so the caller handles masking/averaging.
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        self.init_weights()