class AudioTransformer(nn.Module):
    """BERT-style transformer encoder over audio frame embeddings, with a
    linear classification head.

    When ``use_conv_embedding`` is set, raw 256-sample frames are first passed
    through an ``MSResNet`` convolutional front-end whose bottleneck output is
    768-dim; that size then overrides ``d_model`` as the encoder hidden size.
    """

    # Frame length the conv front-end expects, and the embedding size it emits.
    CONV_FRAME_LEN = 256
    CONV_EMB_SIZE = 768

    def __init__(self, d_model, nhead, dim_feedforward, num_layers, num_classes,
                 dropout=0.1, use_conv_embedding=False, drop_input=0.1):
        """
        :param d_model: encoder hidden size (ignored if ``use_conv_embedding``)
        :param nhead: number of attention heads
        :param dim_feedforward: intermediate (FFN) size of each encoder layer
        :param num_layers: number of encoder layers
        :param num_classes: output classes of the linear decoder
        :param dropout: hidden dropout probability inside the encoder
        :param use_conv_embedding: embed raw frames with MSResNet first
        :param drop_input: dropout applied to the input embeddings
        """
        super().__init__()
        self.use_conv_embedding = use_conv_embedding
        self.hidden_size = d_model
        if use_conv_embedding:
            # MSResNet(1): single input channel; its bottleneck output is
            # CONV_EMB_SIZE-dim, which becomes the encoder hidden size.
            self.conv_embedding = MSResNet(1)
            self.hidden_size = self.CONV_EMB_SIZE
        self.drop_input = nn.Dropout(drop_input)
        # output_attentions=True makes BertModel return the per-layer attention
        # maps as the third element of its output tuple (consumed in forward).
        self.config = BertConfig(hidden_size=self.hidden_size,
                                 num_hidden_layers=num_layers,
                                 intermediate_size=dim_feedforward,
                                 num_attention_heads=nhead,
                                 hidden_dropout_prob=dropout,
                                 output_attentions=True)
        self.encoder = BertModel(self.config)
        self.decoder = SimpleLinearClassifier(self.hidden_size, num_classes, dropout)

    def forward(self, x):
        """Classify a batch of sequences.

        :param x: ``(batch, seq_len, emb)`` embeddings, or — with the conv
            front-end — ``(batch, seq_len, 256)`` raw frames.
        :returns: ``(logits, attentions)``
        :raises ValueError: if the conv front-end is enabled and the last
            dimension is not ``CONV_FRAME_LEN``.
        """
        if self.use_conv_embedding:
            # Explicit check instead of `assert`, which is stripped under -O.
            if x.shape[-1] != self.CONV_FRAME_LEN:
                raise ValueError(
                    f'expected last dimension {self.CONV_FRAME_LEN}, '
                    f'got {x.shape[-1]}')
            batch_size = x.shape[0]
            # (batch * seq_len, 1, frame) where 1 is the conv channel count.
            x = x.reshape(-1, 1, self.CONV_FRAME_LEN)
            # MSResNet returns (classification, bottlenecks); keep bottlenecks.
            x = self.conv_embedding(x)[1]
            # Back to (batch, seq_len, conv_emb_size).
            x = x.reshape(batch_size, -1, self.CONV_EMB_SIZE)
        x = self.drop_input(x)
        # outputs = (hidden_states, pooled_output, attentions); the pooled
        # output is the token enforced to summarize the whole sequence.
        outputs = self.encoder.forward(inputs_embeds=x)
        pooled = outputs[1]
        attentions = outputs[2]
        out = self.decoder(pooled)
        return out, attentions
class KorSTSModel(nn.Module):
    """BERT encoder with a single-logit regression head for KorSTS
    sentence-similarity scoring."""

    def __init__(self, bert_config: BertConfig, dropout_prob: float):
        """
        :param bert_config: configuration used to build the BERT encoder
        :param dropout_prob: dropout applied to the pooled output
        """
        super().__init__()
        self.config = bert_config
        self.bert = BertModel(bert_config)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(bert_config.hidden_size, 1)

    def forward(self, input_token_ids: torch.Tensor, attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor):
        """Encode the inputs and regress one similarity logit per example.

        :returns: ``(batch, 1)`` logits from the pooled [CLS] representation.
        """
        encoded = self.bert.forward(input_token_ids, attention_mask, token_type_ids)
        # Second element of the BERT output tuple is the pooled output.
        pooled_output = encoded[1]
        return self.classifier(self.dropout(pooled_output))
class BertFold(nn.Module):
    """Protein contact/fold model: a ProtBERT encoder followed by a
    pairwise-distance decoder.

    Angle (phi/psi) decoders are present but disabled (commented out).
    """

    def __init__(self, pretrained: bool = True):
        """
        :param pretrained: load 'Rostlab/prot_bert' weights; otherwise build
            an untrained BertModel from the same config.
        """
        super().__init__()
        if pretrained:
            self.bert = BertModel.from_pretrained('Rostlab/prot_bert')
        else:
            conf = BertConfig.from_pretrained('Rostlab/prot_bert')
            self.bert = BertModel(conf)
        # noinspection PyUnresolvedReferences
        dim = self.bert.config.hidden_size
        self.decoder_dist = PairwiseDistanceDecoder(dim)
        # self.decoder_phi = ElementwiseAngleDecoder(dim, 2)
        # self.decoder_psi = ElementwiseAngleDecoder(dim, 2)
        # Re-initialize decoder weights (the encoder keeps its own init).
        self.decoder_dist.apply(init_weights)
        # self.decoder_phi.apply(init_weights)
        # self.decoder_psi.apply(init_weights)

    def forward(
            self,
            input_ids,
            attention_mask=None,
            targets: Optional[BertFoldTargets] = None,
    ) -> BertFoldOutput:
        """Encode the sequence and decode pairwise distances.

        :param input_ids: token ids for the protein sequence batch
        :param attention_mask: optional padding mask forwarded to BERT
        :param targets: when given, losses and metrics are computed; when
            ``None`` only predictions are returned.
        :returns: ``BertFoldOutput`` with predictions (and, if targets were
            supplied, the summed loss plus MAE metrics).
        """
        # Last hidden states: (batch, seq_len, hidden).
        x = self.bert.forward(input_ids, attention_mask=attention_mask)[0]

        targets_dist = None if targets is None else targets.dist
        # targets_phi = None if targets is None else targets.phi
        # targets_psi = None if targets is None else targets.psi

        # Each decoder returns an object carrying y_hat / loss / loss_and_cnt.
        outs = [
            self.decoder_dist.forward(x, targets_dist),
            # self.decoder_phi.forward(x, targets_phi),
            # self.decoder_psi.forward(x, targets_psi),
        ]
        y_hat = tuple(x.y_hat for x in outs)

        # Inference path: no targets, so no loss/metrics to report.
        if targets is None:
            return BertFoldOutput(y_hat=y_hat, )

        # Total loss is the sum over all enabled decoders.
        loss = torch.stack([x.loss for x in outs]).sum()

        # Collect metrics
        with torch.no_grad():
            # Long range MAE metrics: mean absolute error on residue pairs
            # under the 8 Angstrom contact threshold.
            mae_l8_fn = MAEForSeq(contact_thre=8.)
            results = mae_l8_fn(
                inputs=y_hat[0][targets.dist.indices],
                targets=targets.dist.values,
                indices=targets.dist.indices,
            )
            # (mean value, sample count); (0, 0) when no pairs qualified.
            if len(results) > 0:
                mae_l_8 = (results.mean().detach().item(), len(results))
            else:
                mae_l_8 = (0, 0)

            # Top L/5 precision metrics
            # top_l5_precision_fn = TopLNPrecision(n=5, contact_thre=8.)
            # results = top_l5_precision_fn(
            #     inputs=out_dist.y_hat[targets.dist.indices],
            #     targets=targets.dist.values,
            #     indices=targets.dist.indices,
            #     seq_lens=attention_mask.sum(-1) - 2,
            # )
            # if len(results) > 0:
            #     top_l5_precision = (results.mean().detach().item(), len(results))
            # else:
            #     top_l5_precision = (0, 0)

        return BertFoldOutput(
            y_hat=y_hat,
            loss=loss,
            loss_dist=outs[0].loss_and_cnt,
            # loss_phi=outs[1].loss_and_cnt,
            # loss_psi=outs[2].loss_and_cnt,
            mae_l_8=mae_l_8,
        )