Example #1
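All of the snippets below are assumed to share the following imports; EncoderSentences, ClinicalBertWrapper, LinearizedCodesTransformer, PredictTargets, Clusterer, ta, set_dropout, set_require_grad, and tensor_to_none are project-local helpers that the examples use but do not define.

# Assumed shared imports for the examples below (standard PyTorch only).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint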
class Model(nn.Module):
    def __init__(self,
                 num_codes=0,
                 num_linearization_embeddings=0,
                 outdim=64,
                 sentences_per_checkpoint=10,
                 device1='cpu',
                 device2='cpu',
                 freeze_bert=True,
                 num_code_embedding_types=1,
                 dropout=.15):
        super(Model, self).__init__()
        self.num_codes = num_codes
        self.num_linearization_embeddings = num_linearization_embeddings
        self.clinical_bert_sentences = EncoderSentences(
            ClinicalBertWrapper,
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device1)
        if freeze_bert:
            self.freeze_bert()
        else:
            self.unfreeze_bert(dropout=dropout)
        self.code_embeddings = nn.Embedding(
            num_codes, outdim) if num_codes > 0 else None
        self.linearized_code_transformer = EncoderSentences(
            lambda: LinearizedCodesTransformer(num_linearization_embeddings),
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device2) if num_linearization_embeddings > 0 else None
        self.predict_targets = PredictTargets(outdim)
        self.linear2 = nn.Linear(
            outdim * num_code_embedding_types,
            outdim) if num_code_embedding_types > 1 else None
        self.device1 = device1
        self.device2 = device2
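When more than one code-embedding type is active, these variants concatenate the per-type embeddings along the feature dimension and project back to outdim with linear2 (see Example #6's inner_forward). A minimal self-contained sketch of that combination step, with hypothetical shapes:

outdim, num_types = 64, 2
linear2 = nn.Linear(outdim * num_types, outdim)
emb_codes = torch.randn(4, 10, outdim)  # e.g. learned code embeddings
emb_descs = torch.randn(4, 10, outdim)  # e.g. encoded code descriptions
# concatenate per-type embeddings, then project back to outdim
combined = linear2(torch.cat([emb_codes, emb_descs], dim=2))  # (4, 10, 64)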
Example #2
class Model(nn.Module):
    def __init__(self,
                 num_codes,
                 outdim=64,
                 sentences_per_checkpoint=10,
                 device='cpu'):
        super(Model, self).__init__()
        self.num_codes = num_codes
        self.clinical_bert_sentences = EncoderSentences(
            ClinicalBertWrapper,
            pool_type="mean",
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device)
        self.device = device

    def correct_devices(self):
        self.clinical_bert_sentences.correct_devices()

    def forward(self, article_sentences, article_sentences_lengths, num_codes,
                code_description, code_description_length):
        encodings, self_attentions, word_level_attentions = self.clinical_bert_sentences(
            article_sentences, article_sentences_lengths)
        # b, ns, nt = word_level_attentions.shape
        b, ns, nl, nh, nt, _ = self_attentions.shape
        traceback_word_level_attentions = ta(
            self_attentions.mean(3).view(b*ns, nl, nt, nt),
            attention_vecs=word_level_attentions.view(b*ns, 1, nt))\
            .view(b, ns, nt)
        code_embeddings = self.clinical_bert_sentences(
            code_description, code_description_length)[0]
        key_padding_mask = (
            article_sentences_lengths == 0)[:, :encodings.size(1)]
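        # NOTE: key_padding_mask is computed but unused in this variant;
        # later variants pass it to nn.MultiheadAttention.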
        sentence_level_attentions = ((code_embeddings.unsqueeze(-2)
                                      - encodings.unsqueeze(-3))**2).sum(-1)
        sentence_level_attentions = sentence_level_attentions \
            / sentence_level_attentions.sum(2, keepdim=True)
        nq = code_description.shape[1]
        word_level_attentions = word_level_attentions\
            .view(b, 1, ns, nt)\
            .expand(b, nq, ns, nt)
        traceback_word_level_attentions = traceback_word_level_attentions\
            .view(b, 1, ns, nt)\
            .expand(b, nq, ns, nt)
        attention = word_level_attentions \
            * sentence_level_attentions.unsqueeze(3)
        traceback_attention = traceback_word_level_attentions \
            * sentence_level_attentions.unsqueeze(3)
        return_dict = dict(num_codes=num_codes,
                           attention=attention,
                           traceback_attention=traceback_attention,
                           article_sentences_lengths=article_sentences_lengths)
        return return_dict
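ta is a project-local helper that traces the word-level attention back through the encoder's per-layer self-attention maps. Its implementation is not shown; a minimal sketch of one plausible version (attention rollout in the style of Abnar & Zuidema, 2020; the name and exact semantics here are assumptions):

def attention_rollout(self_attentions, attention_vecs):
    # self_attentions: (batch, layers, tokens, tokens), already head-averaged
    # attention_vecs:  (batch, 1, tokens), attention over the top layer
    # Multiplying back through each layer's map re-expresses the vector in
    # terms of input tokens. (Rollout proper also folds in the residual
    # stream, e.g. 0.5 * A + 0.5 * I per layer; omitted for brevity.)
    result = attention_vecs
    for layer in range(self_attentions.size(1) - 1, -1, -1):
        result = torch.matmul(result, self_attentions[:, layer])
    return result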
Example #3
class Model(nn.Module):
    def __init__(self,
                 num_codes,
                 outdim=64,
                 sentences_per_checkpoint=10,
                 device='cpu'):
        super(Model, self).__init__()
        self.num_codes = num_codes
        self.clinical_bert_sentences = EncoderSentences(
            ClinicalBertWrapper,
            pool_type="mean",
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device)
        self.device = device
Example #4
class Model(nn.Module):
    def __init__(self,
                 outdim=64,
                 sentences_per_checkpoint=10,
                 device='cpu',
                 cluster=False):
        super(Model, self).__init__()
        self.clinical_bert_sentences = EncoderSentences(
            ClinicalBertWrapper,
            pool_type="mean",
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device)
        self.device = device
        self.cluster = cluster
        self.clusterer = Clusterer() if cluster else None
Example #5
class Model(nn.Module):
    def __init__(self,
                 outdim=64,
                 sentences_per_checkpoint=10,
                 device1='cpu',
                 device2='cpu',
                 freeze_bert=True,
                 code_embedding_type_params=dict(),
                 concatenate_code_embedding=False,
                 dropout=.15,
                 cluster=False):
        super(Model, self).__init__()
        self.clinical_bert_sentences = EncoderSentences(
            ClinicalBertWrapper,
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device1)
        if freeze_bert:
            self.freeze_bert()
        else:
            self.unfreeze_bert(dropout=dropout)
        self.code_embedding_type_params = code_embedding_type_params
        num_code_embedding_types = len(code_embedding_type_params)
        self.code_embeddings = nn.Embedding(
            code_embedding_type_params['codes'][0],
            outdim) if 'codes' in code_embedding_type_params else None
        # Assumption: the linearization vocabulary size is the first element
        # of code_embedding_type_params['linearized_codes'], by analogy with
        # 'codes' above (the original referenced an undefined
        # num_linearization_embeddings here).
        self.linearized_code_transformer = EncoderSentences(
            lambda: LinearizedCodesTransformer(
                code_embedding_type_params['linearized_codes'][0]),
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device2
        ) if 'linearized_codes' in code_embedding_type_params else None
        self.attention = nn.MultiheadAttention(outdim, 1)
        self.concatenate_code_embedding = concatenate_code_embedding
        self.linear = nn.Linear(outdim, 1)
        self.linear2 = nn.Linear(
            outdim * num_code_embedding_types,
            outdim) if num_code_embedding_types > 1 else None
        self.linear3 = nn.Linear(
            2 * outdim, outdim) if concatenate_code_embedding else None
        self.device1 = device1
        self.device2 = device2
        self.cluster = cluster
        self.clusterer = Clusterer() if cluster else None
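The indexing above implies that code_embedding_type_params is a mapping rather than a set: each key names a code-embedding type, and the first element of its value is the vocabulary size for that type. A hypothetical example (the layout beyond index 0 is an assumption):

code_embedding_type_params = {
    'codes': (5000,),            # vocab size -> nn.Embedding(5000, outdim)
    'linearized_codes': (128,),  # -> LinearizedCodesTransformer(128)
}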
Example #6
class Model(nn.Module):
    def __init__(self,
                 outdim=64,
                 sentences_per_checkpoint=10,
                 device1='cpu',
                 device2='cpu',
                 freeze_bert=True,
                 code_embedding_type_params=dict(),
                 concatenate_code_embedding=False,
                 dropout=.15,
                 cluster=False):
        super(Model, self).__init__()
        self.clinical_bert_sentences = EncoderSentences(
            ClinicalBertWrapper,
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device1)
        if freeze_bert:
            self.freeze_bert()
        else:
            self.unfreeze_bert(dropout=dropout)
        self.code_embedding_type_params = code_embedding_type_params
        num_code_embedding_types = len(code_embedding_type_params)
        self.code_embeddings = nn.Embedding(
            code_embedding_type_params['codes'][0],
            outdim) if 'codes' in code_embedding_type_params else None
        # Assumption: the linearization vocabulary size is the first element
        # of code_embedding_type_params['linearized_codes'], by analogy with
        # 'codes' above (the original referenced an undefined
        # num_linearization_embeddings here).
        self.linearized_code_transformer = EncoderSentences(
            lambda: LinearizedCodesTransformer(
                code_embedding_type_params['linearized_codes'][0]),
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device2
        ) if 'linearized_codes' in code_embedding_type_params else None
        self.attention = nn.MultiheadAttention(outdim, 1)
        self.concatenate_code_embedding = concatenate_code_embedding
        self.linear = nn.Linear(outdim, 1)
        self.linear2 = nn.Linear(
            outdim * num_code_embedding_types,
            outdim) if num_code_embedding_types > 1 else None
        self.linear3 = nn.Linear(
            2 * outdim, outdim) if concatenate_code_embedding else None
        self.device1 = device1
        self.device2 = device2
        self.cluster = cluster
        self.clusterer = Clusterer() if cluster else None

    def freeze_bert(self):
        set_dropout(self.clinical_bert_sentences, 0)
        set_require_grad(self.clinical_bert_sentences, False)

    def unfreeze_bert(self, dropout=.15):
        set_dropout(self.clinical_bert_sentences, dropout)
        set_require_grad(self.clinical_bert_sentences, True)

    def correct_devices(self):
        self.clinical_bert_sentences.correct_devices()
        if self.code_embeddings is not None:
            self.code_embeddings.to(self.device2)
        self.attention.to(self.device2)
        self.linear.to(self.device2)
        if self.linear2 is not None:
            self.linear2.to(self.device2)
        self.codes_per_checkpoint = 1000
        if self.cluster:
            self.clusterer.to(self.device2)
        if self.concatenate_code_embedding:
            self.linear3.to(self.device2)

    def forward(self,
                article_sentences,
                article_sentences_lengths,
                num_codes,
                codes=None,
                code_description=None,
                code_description_length=None,
                linearized_codes=None,
                linearized_codes_lengths=None,
                linearized_descriptions=None,
                linearized_descriptions_lengths=None):
        nq = num_codes.max()
        nq_temp = self.codes_per_checkpoint
        scores, attention, traceback_attention, context_vec = [], [], [], []
        for offset in range(0, nq, nq_temp):
            if codes is not None:
                codes_temp = codes[:, offset:offset + nq_temp]
            else:
                codes_temp = torch.zeros(0)
            if code_description is not None:
                code_description_temp = \
                    code_description[:, offset:offset + nq_temp]
                code_description_length_temp = \
                    code_description_length[:, offset:offset + nq_temp]
            else:
                code_description_temp = torch.zeros(0)
                code_description_length_temp = torch.zeros(0)
            if linearized_codes is not None:
                linearized_codes_temp = \
                    linearized_codes[:, offset:offset + nq_temp]
                linearized_codes_lengths_temp = \
                    linearized_codes_lengths[:, offset:offset + nq_temp]
            else:
                linearized_codes_temp = torch.zeros(0)
                linearized_codes_lengths_temp = torch.zeros(0)
            if linearized_descriptions is not None:
                linearized_descriptions_temp = \
                    linearized_descriptions[:, offset:offset + nq_temp]
                linearized_descriptions_lengths_temp = \
                    linearized_descriptions_lengths[:, offset:offset + nq_temp]
            else:
                linearized_descriptions_temp = torch.zeros(0)
                linearized_descriptions_lengths_temp = torch.zeros(0)
            num_codes_temp = torch.clamp(num_codes - offset, 0, nq_temp)
            scores_temp, attention_temp, traceback_attention_temp, context_vec_temp = checkpoint(
                self.inner_forward, article_sentences,
                article_sentences_lengths, num_codes_temp, codes_temp,
                code_description_temp, code_description_length_temp,
                linearized_codes_temp, linearized_codes_lengths_temp,
                linearized_descriptions_temp,
                linearized_descriptions_lengths_temp, *self.parameters())
            scores.append(scores_temp)
            attention.append(attention_temp)
            traceback_attention.append(traceback_attention_temp)
            context_vec.append(context_vec_temp)
        scores = torch.cat(scores, 1)
        attention = torch.cat(attention, 1)
        traceback_attention = torch.cat(traceback_attention, 1)
        context_vec = torch.cat(context_vec, 1)
        if self.cluster:
            clustering = self.clusterer(article_sentences,
                                        article_sentences_lengths, attention,
                                        num_codes)
        else:
            clustering = None
        return_dict = dict(scores=scores,
                           num_codes=num_codes,
                           attention=attention,
                           traceback_attention=traceback_attention,
                           article_sentences_lengths=article_sentences_lengths,
                           clustering=clustering,
                           context_vec=context_vec)
        if codes is not None:
            return_dict['codes'] = codes
        return return_dict

    def inner_forward(self, article_sentences, article_sentences_lengths,
                      num_codes, codes, code_description,
                      code_description_length, linearized_codes,
                      linearized_codes_lengths, linearized_descriptions,
                      linearized_descriptions_lengths, *args):
        (codes, code_description, code_description_length,
         linearized_codes, linearized_codes_lengths,
         linearized_descriptions, linearized_descriptions_lengths) = (
             tensor_to_none(t)
             for t in (codes, code_description, code_description_length,
                       linearized_codes, linearized_codes_lengths,
                       linearized_descriptions,
                       linearized_descriptions_lengths))
        encodings, self_attentions, word_level_attentions = self.clinical_bert_sentences(
            article_sentences, article_sentences_lengths)
        article_sentences_lengths, num_codes, encodings, self_attentions, word_level_attentions =\
            article_sentences_lengths.to(self.device2), num_codes.to(self.device2), encodings.to(self.device2), self_attentions.to(self.device2), word_level_attentions.to(self.device2)
        # b, ns, nt = word_level_attentions.shape
        b, ns, nl, nh, nt, _ = self_attentions.shape
        traceback_word_level_attentions = ta(
            self_attentions.mean(3).view(b*ns, nl, nt, nt),
            attention_vecs=word_level_attentions.view(b*ns, 1, nt))\
            .view(b, ns, nt)
        if (codes is None and code_description is None
                and linearized_codes is None
                and linearized_descriptions is None):
            raise ValueError(
                'at least one of codes, code_description, linearized_codes, '
                'or linearized_descriptions must be provided')
        all_code_embeddings = []
        if codes is not None:
            codes = codes.to(self.device2)
            all_code_embeddings.append(self.code_embeddings(codes))
        if code_description is not None:
            all_code_embeddings.append(
                self.clinical_bert_sentences(
                    code_description,
                    code_description_length,
                )[0].to(self.device2))
        if linearized_codes is not None:
            all_code_embeddings.append(
                self.linearized_code_transformer(linearized_codes,
                                                 linearized_codes_lengths)[0])
        if linearized_descriptions is not None:
            all_code_embeddings.append(
                self.clinical_bert_sentences(
                    linearized_descriptions,
                    linearized_descriptions_lengths,
                )[0].to(self.device2))
        if self.linear2 is not None:
            code_embeddings = torch.cat(all_code_embeddings, 2)
            code_embeddings = self.linear2(code_embeddings)
        else:
            code_embeddings = all_code_embeddings[0]
        key_padding_mask = (
            article_sentences_lengths == 0)[:, :encodings.size(1)]
        contextvec, sentence_level_attentions = self.attention(
            code_embeddings.transpose(0, 1),
            encodings.transpose(0, 1),
            encodings.transpose(0, 1),
            key_padding_mask=key_padding_mask)
        nq, _, emb_dim = contextvec.shape
        word_level_attentions = word_level_attentions\
            .view(b, 1, ns, nt)\
            .expand(b, nq, ns, nt)
        traceback_word_level_attentions = traceback_word_level_attentions\
            .view(b, 1, ns, nt)\
            .expand(b, nq, ns, nt)
        attention = word_level_attentions \
            * sentence_level_attentions.unsqueeze(3)
        traceback_attention = traceback_word_level_attentions \
            * sentence_level_attentions.unsqueeze(3)
        if self.concatenate_code_embedding:
            encoding = torch.cat(
                [contextvec, code_embeddings.transpose(0, 1)], 2)
            encoding = torch.relu(self.linear3(encoding))
        else:
            encoding = contextvec
        scores = self.linear(encoding)
        return (scores.transpose(0, 1).squeeze(2), attention,
                traceback_attention, contextvec.transpose(0, 1))
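The forward above bounds activation memory by slicing the code dimension into chunks of codes_per_checkpoint and recomputing each chunk's activations during backward; passing *self.parameters() as extra checkpoint arguments is the usual workaround so that reentrant checkpointing still produces gradients for parameters that are not direct inputs. A stripped-down, self-contained sketch of the same pattern (names hypothetical):

def chunked_checkpointed_apply(fn, queries, chunk_size):
    # Apply fn to (batch, nq, dim) queries in slices along dim 1, trading
    # recomputation during backward for a smaller activation footprint.
    outs = []
    for offset in range(0, queries.size(1), chunk_size):
        chunk = queries[:, offset:offset + chunk_size]
        # use_reentrant=False (PyTorch >= 1.11) makes the *parameters
        # workaround unnecessary
        outs.append(checkpoint(fn, chunk, use_reentrant=False))
    return torch.cat(outs, 1)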
Example #7
class Model(nn.Module):
    def __init__(self,
                 num_codes=0,
                 num_linearization_embeddings=0,
                 outdim=64,
                 sentences_per_checkpoint=10,
                 device1='cpu',
                 device2='cpu',
                 freeze_bert=True,
                 num_code_embedding_types=1,
                 dropout=.15):
        super(Model, self).__init__()
        self.num_codes = num_codes
        self.num_linearization_embeddings = num_linearization_embeddings
        self.clinical_bert_sentences = EncoderSentences(
            ClinicalBertWrapper,
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device1)
        if freeze_bert:
            self.freeze_bert()
        else:
            self.unfreeze_bert(dropout=dropout)
        self.code_embeddings = nn.Embedding(
            num_codes, outdim) if num_codes > 0 else None
        self.linearized_code_transformer = EncoderSentences(
            lambda: LinearizedCodesTransformer(num_linearization_embeddings),
            embedding_dim=outdim,
            truncate_tokens=50,
            truncate_sentences=1000,
            sentences_per_checkpoint=sentences_per_checkpoint,
            device=device2) if num_linearization_embeddings > 0 else None
        self.predict_targets = PredictTargets(outdim)
        self.linear2 = nn.Linear(
            outdim * num_code_embedding_types,
            outdim) if num_code_embedding_types > 1 else None
        self.device1 = device1
        self.device2 = device2

    def freeze_bert(self):
        set_dropout(self.clinical_bert_sentences, 0)
        set_require_grad(self.clinical_bert_sentences, False)

    def unfreeze_bert(self, dropout=.15):
        set_dropout(self.clinical_bert_sentences, dropout)
        set_require_grad(self.clinical_bert_sentences, True)

    def correct_devices(self):
        self.clinical_bert_sentences.correct_devices()
        if self.code_embeddings is not None:
            self.code_embeddings.to(self.device2)
        self.predict_targets.to(self.device2)
        if self.linear2 is not None:
            self.linear2.to(self.device2)
        self.codes_per_checkpoint = 1000

    def forward(self,
                article_sentences,
                article_sentences_lengths,
                num_codes,
                codes=None,
                code_description=None,
                code_description_length=None,
                linearized_codes=None,
                linearized_codes_lengths=None,
                linearized_descriptions=None,
                linearized_descriptions_lengths=None):
        nq = num_codes.max()
        nq_temp = self.codes_per_checkpoint
        scores, word_level_attentions, traceback_word_level_attentions, sentence_level_scores = [], [], [], []
        for offset in range(0, nq, nq_temp):
            if codes is not None:
                codes_temp = codes[:, offset:offset + nq_temp]
            else:
                codes_temp = torch.zeros(0)
            if code_description is not None:
                code_description_temp = \
                    code_description[:, offset:offset + nq_temp]
                code_description_length_temp = \
                    code_description_length[:, offset:offset + nq_temp]
            else:
                code_description_temp = torch.zeros(0)
                code_description_length_temp = torch.zeros(0)
            if linearized_codes is not None:
                linearized_codes_temp = \
                    linearized_codes[:, offset:offset + nq_temp]
                linearized_codes_lengths_temp = \
                    linearized_codes_lengths[:, offset:offset + nq_temp]
            else:
                linearized_codes_temp = torch.zeros(0)
                linearized_codes_lengths_temp = torch.zeros(0)
            num_codes_temp = torch.clamp(num_codes - offset, 0, nq_temp)
            scores_temp, word_level_attentions_temp, traceback_word_level_attentions_temp, sentence_level_scores_temp = checkpoint(
                self.inner_forward, article_sentences,
                article_sentences_lengths, num_codes_temp, codes_temp,
                code_description_temp, code_description_length_temp,
                linearized_codes_temp, linearized_codes_lengths_temp,
                *self.parameters())
            scores.append(scores_temp)
            word_level_attentions.append(word_level_attentions_temp)
            traceback_word_level_attentions.append(
                traceback_word_level_attentions_temp)
            sentence_level_scores.append(sentence_level_scores_temp)
        scores = torch.cat(scores, 1)
        word_level_attentions = torch.cat(word_level_attentions, 1)
        traceback_word_level_attentions = torch.cat(
            traceback_word_level_attentions, 1)
        sentence_level_scores = torch.cat(sentence_level_scores, 1)
        return_dict = dict(
            scores=scores,
            num_codes=num_codes,
            word_level_attentions=word_level_attentions,
            traceback_word_level_attentions=traceback_word_level_attentions,
            sentence_level_scores=sentence_level_scores,
            article_sentences_lengths=article_sentences_lengths)
        if codes is not None:
            return_dict['codes'] = codes
        return return_dict

    def inner_forward(self, article_sentences, article_sentences_lengths,
                      num_codes, codes, code_description,
                      code_description_length, linearized_codes,
                      linearized_codes_lengths, *args):
        (codes, code_description, code_description_length, linearized_codes,
         linearized_codes_lengths) = (
             tensor_to_none(t)
             for t in (codes, code_description, code_description_length,
                       linearized_codes, linearized_codes_lengths))
        encodings, self_attentions, word_level_attentions = self.clinical_bert_sentences(
            article_sentences, article_sentences_lengths)
        article_sentences_lengths, num_codes, encodings, self_attentions, word_level_attentions =\
            article_sentences_lengths.to(self.device2), num_codes.to(self.device2), encodings.to(self.device2), self_attentions.to(self.device2), word_level_attentions.to(self.device2)
        # b, ns, nt = word_level_attentions.shape
        b, ns, nl, nh, nt, _ = self_attentions.shape
        traceback_word_level_attentions = ta(
            self_attentions.mean(3).view(b*ns, nl, nt, nt),
            attention_vecs=word_level_attentions.view(b*ns, 1, nt))\
            .view(b, ns, nt)
        if (codes is None and code_description is None
                and linearized_codes is None):
            raise ValueError('at least one of codes, code_description, or '
                             'linearized_codes must be provided')
        all_code_embeddings = []
        if codes is not None:
            codes = codes.to(self.device2)
            all_code_embeddings.append(self.code_embeddings(codes))
        if code_description is not None:
            all_code_embeddings.append(
                self.clinical_bert_sentences(
                    code_description,
                    code_description_length,
                )[0].to(self.device2))
        if linearized_codes is not None:
            all_code_embeddings.append(
                self.linearized_code_transformer(linearized_codes,
                                                 linearized_codes_lengths)[0])
        if self.linear2 is not None:
            code_embeddings = torch.cat(all_code_embeddings, 2)
            code_embeddings = self.linear2(code_embeddings)
        else:
            code_embeddings = all_code_embeddings[0]
        sentence_level_scores = self.predict_targets(
            code_embeddings.transpose(0, 1),
            encodings.transpose(0, 1)).transpose(0, 2).transpose(1, 2)
        mask = (article_sentences_lengths != 0)[:, :encodings.size(1)]\
            .unsqueeze(1).expand(sentence_level_scores.shape)
        # sentence_level_scores_masked = sentence_level_scores * mask
        sentence_level_scores_masked = torch.zeros_like(
            sentence_level_scores).masked_scatter(
                mask, sentence_level_scores[mask])
        scores = sentence_level_scores_masked.sum(2) / mask.sum(2)
        requires_grad = scores.requires_grad
        return (scores,
                word_level_attentions.requires_grad_(requires_grad),
                traceback_word_level_attentions.requires_grad_(requires_grad),
                sentence_level_scores)
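The masked average at the end of inner_forward uses masked_scatter rather than multiplying by the mask, so padded positions never enter the computation (a plain product would still propagate any NaN or inf sitting in the padding). A minimal self-contained sketch of that step:

def masked_mean(scores, mask):
    # scores: (batch, nq, ns) raw sentence-level scores
    # mask:   (batch, nq, ns) bool, True where the sentence is real
    masked = torch.zeros_like(scores).masked_scatter(mask, scores[mask])
    # caveat: rows whose mask is all False divide by zero here
    return masked.sum(2) / mask.sum(2)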