def step(self, model: BertFold, batch: ProteinNetBatch) -> StepResult:
    """Run the model on one batch and wrap its outputs in a ``StepResult``.

    Args:
        model: The model to evaluate. It is invoked as ``model(...)`` rather
            than ``model.forward(...)`` so that module hooks are not skipped.
        batch: A collated batch. ``batch['input_ids']`` is used only to count
            how many samples were processed.

    Returns:
        A ``StepResult`` with ``n_processed = len(batch['input_ids'])`` plus
        every entry of the model output except ``'y_hat'``, passed as keyword
        arguments.
    """
    targets = prepare_targets(batch)
    # Call the module instance, not ``.forward``: ``nn.Module.__call__`` runs
    # registered forward (pre-)hooks and is what wrappers such as
    # DistributedDataParallel rely on. Assumes BertFold is a torch.nn.Module.
    out = model(batch, targets=targets)
    # NOTE(review): 'y_hat' is presumably the raw prediction tensor and is
    # dropped to keep the result small -- confirm against StepResult's fields.
    extras = {key: value for key, value in out.items() if key != 'y_hat'}
    return StepResult(n_processed=len(batch['input_ids']), **extras)
) # %% if __name__ == '__main__': # %% from torch.utils.data import DataLoader from bert_fold.dataset import ProteinNetDataset, prepare_targets from bert_fold.dto.batch import ProteinNetBatch from const import DATA_PROTEIN_NET_DIR import pandas as pd # %% loader = DataLoader(ProteinNetDataset( pd.read_parquet(DATA_PROTEIN_NET_DIR / f'casp12/validation.pqt')), batch_size=2, collate_fn=ProteinNetDataset.collate, shuffle=False) # %% model = BertFold(pretrained=False) # %% batch: ProteinNetBatch = next(iter(loader)) targets = prepare_targets(batch) out = model.forward(batch, targets=targets) pass # %% for k, v in model.named_parameters(): print(k)