Example No. 1
0
class AudioTransformer(nn.Module):
    """BERT-style transformer encoder with a linear classification head for audio.

    Input frames are either fed to the encoder directly as embeddings of width
    ``d_model`` or, when ``use_conv_embedding`` is set, first passed through an
    MSResNet conv stem whose bottleneck features become the token embeddings.
    """

    # Width of the MSResNet bottleneck features used as token embeddings,
    # and the raw frame length the conv stem expects.
    CONV_EMBEDDING_SIZE = 768
    CONV_FRAME_SIZE = 256

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward,
                 num_layers,
                 num_classes,
                 dropout=0.1,
                 use_conv_embedding=False,
                 drop_input=0.1):
        """
        Args:
            d_model: embedding width when frames are used directly.
            nhead: number of attention heads per encoder layer.
            dim_feedforward: intermediate (feed-forward) size of each layer.
            num_layers: number of encoder layers.
            num_classes: number of output classes of the linear decoder.
            dropout: hidden dropout probability inside the encoder.
            use_conv_embedding: if True, embed raw 256-sample frames with
                MSResNet instead of using them as embeddings directly.
            drop_input: dropout applied to conv embeddings before the encoder.
        """
        super().__init__()
        self.use_conv_embedding = use_conv_embedding
        self.hidden_size = d_model
        if use_conv_embedding:
            # MSResNet(1): single input channel; its bottleneck features
            # replace d_model as the encoder's hidden size.
            self.conv_embedding = MSResNet(1)
            self.hidden_size = self.CONV_EMBEDDING_SIZE
        self.drop_input = nn.Dropout(drop_input)
        self.config = BertConfig(hidden_size=self.hidden_size,
                                 num_hidden_layers=num_layers,
                                 intermediate_size=dim_feedforward,
                                 num_attention_heads=nhead,
                                 hidden_dropout_prob=dropout,
                                 output_attentions=True)
        self.encoder = BertModel(self.config)
        self.decoder = SimpleLinearClassifier(self.hidden_size, num_classes,
                                              dropout)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: (batch, seq_len, d_model) embeddings, or raw frames of shape
               (batch, seq_len, 256) when ``use_conv_embedding`` is enabled.

        Returns:
            Tuple ``(logits, attentions)`` where ``logits`` comes from the
            linear decoder over the pooled output and ``attentions`` are the
            per-layer attention maps from the encoder.

        Raises:
            ValueError: if conv embedding is enabled and the last dimension
                is not the expected frame size.
        """
        if self.use_conv_embedding:
            # Raise instead of assert: asserts are stripped under `python -O`.
            if x.shape[-1] != self.CONV_FRAME_SIZE:
                raise ValueError(
                    f'expected frames of size {self.CONV_FRAME_SIZE}, '
                    f'got {x.shape[-1]}')
            batch_size = x.shape[0]
            # Fold the sequence dim into the batch so the conv stem sees
            # (batch * seq_len, 1, frame_size) with a single conv channel.
            x = x.reshape(-1, 1, self.CONV_FRAME_SIZE)
            # MSResNet returns (classification, bottlenecks); keep bottlenecks.
            x = self.conv_embedding(x)[1]
            # Restore (batch, seq_len, conv_emb_size).
            x = x.reshape(batch_size, -1, self.hidden_size)
            x = self.drop_input(x)
        # Call the module, not .forward(), so registered hooks run.
        outputs = self.encoder(inputs_embeds=x)
        # With output_attentions=True, BertModel yields
        # (sequence_output, pooled_output, attentions); the pooled vector
        # summarizes the whole sequence.
        pooled = outputs[1]
        attentions = outputs[2]
        out = self.decoder(pooled)
        return out, attentions
Example No. 2
0
class KorSTSModel(nn.Module):
    """BERT regression head for the KorSTS sentence-similarity task.

    Encodes the paired-sentence input with BERT, applies dropout to the
    pooled output, and projects it to a single similarity score.
    """

    def __init__(self, bert_config: BertConfig, dropout_prob: float):
        """
        Args:
            bert_config: configuration used to build the (untrained) BertModel.
            dropout_prob: dropout applied to the pooled output before the head.
        """
        super().__init__()
        self.config = bert_config

        self.bert = BertModel(bert_config)
        self.dropout = nn.Dropout(dropout_prob)
        # Single output unit: STS similarity is scored as a regression target.
        self.classifier = nn.Linear(bert_config.hidden_size, 1)

    def forward(self, input_token_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor) -> torch.Tensor:
        """Return similarity logits of shape (batch, 1)."""
        # Call the module, not .forward(), so registered hooks run; keyword
        # arguments guard against positional-order drift in transformers.
        _, pooled_output = self.bert(input_ids=input_token_ids,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids)
        output_drop = self.dropout(pooled_output)
        logits = self.classifier(output_drop)

        return logits
Example No. 3
0
class BertFold(nn.Module):
    """Protein model: ProtBert encoder plus a pairwise-distance decoder head.

    The angle decoders (phi/psi) are currently disabled but kept as
    commented-out code for easy re-enabling.
    """

    def __init__(self, pretrained: bool = True):
        """Build the model.

        Args:
            pretrained: if True, load the published 'Rostlab/prot_bert'
                weights; otherwise build an untrained model from that
                checkpoint's configuration only.
        """
        super().__init__()
        if pretrained:
            self.bert = BertModel.from_pretrained('Rostlab/prot_bert')
        else:
            conf = BertConfig.from_pretrained('Rostlab/prot_bert')
            self.bert = BertModel(conf)

        # Decoder input width follows the encoder's hidden size.
        # noinspection PyUnresolvedReferences
        dim = self.bert.config.hidden_size

        self.decoder_dist = PairwiseDistanceDecoder(dim)
        # self.decoder_phi = ElementwiseAngleDecoder(dim, 2)
        # self.decoder_psi = ElementwiseAngleDecoder(dim, 2)

        # Re-initialize the decoder weights (BERT weights are left as loaded).
        self.decoder_dist.apply(init_weights)
        # self.decoder_phi.apply(init_weights)
        # self.decoder_psi.apply(init_weights)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        targets: Optional[BertFoldTargets] = None,
    ) -> BertFoldOutput:
        """Encode the sequence and decode pairwise residue distances.

        Args:
            input_ids: token ids of the protein sequence batch.
            attention_mask: optional padding mask forwarded to BERT.
            targets: optional ground truth; when given, loss and metrics
                are computed in addition to the predictions.

        Returns:
            BertFoldOutput holding predictions and, if targets were given,
            the summed loss, the per-decoder loss, and the long-range MAE.
        """
        # Element [0] of the BertModel output is the last hidden states.
        x = self.bert.forward(input_ids, attention_mask=attention_mask)[0]

        targets_dist = None if targets is None else targets.dist
        # targets_phi = None if targets is None else targets.phi
        # targets_psi = None if targets is None else targets.psi

        # One output object per active decoder (currently only distance).
        outs = [
            self.decoder_dist.forward(x, targets_dist),
            # self.decoder_phi.forward(x, targets_phi),
            # self.decoder_psi.forward(x, targets_psi),
        ]

        y_hat = tuple(x.y_hat for x in outs)

        if targets is None:
            # Inference mode: predictions only, no loss/metrics.
            return BertFoldOutput(y_hat=y_hat, )

        # Total loss is the sum over the active decoder losses.
        loss = torch.stack([x.loss for x in outs]).sum()

        # Collect metrics (no gradients needed for reporting).
        with torch.no_grad():
            # Long range MAE metrics: mean absolute error at contact
            # threshold 8 (presumably Angstroms — confirm against MAEForSeq).
            mae_l8_fn = MAEForSeq(contact_thre=8.)
            results = mae_l8_fn(
                inputs=y_hat[0][targets.dist.indices],
                targets=targets.dist.values,
                indices=targets.dist.indices,
            )
            # Report (mean, count); (0, 0) when no pairs qualify.
            if len(results) > 0:
                mae_l_8 = (results.mean().detach().item(), len(results))
            else:
                mae_l_8 = (0, 0)

            # Top L/5 precision metrics
            # top_l5_precision_fn = TopLNPrecision(n=5, contact_thre=8.)
            # results = top_l5_precision_fn(
            #     inputs=out_dist.y_hat[targets.dist.indices],
            #     targets=targets.dist.values,
            #     indices=targets.dist.indices,
            #     seq_lens=attention_mask.sum(-1) - 2,
            # )
            # if len(results) > 0:
            #     top_l5_precision = (results.mean().detach().item(), len(results))
            # else:
            #     top_l5_precision = (0, 0)

        return BertFoldOutput(
            y_hat=y_hat,
            loss=loss,
            loss_dist=outs[0].loss_and_cnt,
            # loss_phi=outs[1].loss_and_cnt,
            # loss_psi=outs[2].loss_and_cnt,
            mae_l_8=mae_l_8,
        )