# Training entry script (disease model V3 / key-point V2 variant).
# Loads the lumbar training and validation studies and builds an image lookup
# table keyed by (study_uid, series_uid, instance_uid).
import sys  # was used below (sys.path.append) but never imported -> NameError
import time  # was used below (time.time) but never imported -> NameError

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# NOTE(review): a top-level package named `code` shares its name with the
# stdlib `code` module; this only resolves to the project package if the
# project root precedes the stdlib on sys.path — confirm how this is launched.
from code.core.disease import DisDataLoader, Evaluator, DiseaseModelV3, DisLoss
from code.core.key_point import SpinalModel, KeyPointModelV2, KeyPointBCELossV2, NullLoss, CascadeLossV2
from code.core.structure import construct_studies

# Relative to the current working directory, not to this file.
sys.path.append('core/nn_tools/')
from core.nn_tools import torch_utils  # noqa: E402  (must follow the sys.path tweak)

if __name__ == '__main__':
    # Record the wall-clock time at which training starts.
    start_time = time.time()

    # Load the training studies, their annotations and a counter object
    # returned by construct_studies. multiprocessing=False: single-process load.
    train_studies, train_annotation, train_counter = construct_studies(
        '../data/lumbar_train150', '../data/lumbar_train150_annotation.json', multiprocessing=False)

    # Same for the validation split (also loaded single-process).
    valid_studies, valid_annotation, valid_counter = construct_studies(
        '../data/lumbar_train51/', '../data/lumbar_train51_annotation.json', multiprocessing=False)

    # Map (study_uid, series_uid, instance_uid) -> image of each training
    # study's middle T2 sagittal frame.
    train_images = {}
    for study_uid, study in train_studies.items():
        frame = study.t2_sagittal_middle_frame
        train_images[(study_uid, frame.series_uid, frame.instance_uid)] = frame.image
# Training entry script (DiseaseModelBase / KeyPointModel variant).
# Loads the lumbar training and validation studies, builds an image lookup
# table, and constructs a ResNet-50 FPN backbone for the key-point model.
import sys  # was used below (sys.path.append) but never imported -> NameError
import time  # was used below (time.time) but never imported -> NameError

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# NOTE(review): a top-level package named `code` shares its name with the
# stdlib `code` module; confirm the project root precedes the stdlib on sys.path.
from code.core.disease.data_loader import DisDataLoader
from code.core.disease.evaluation import Evaluator
from code.core.disease.model import DiseaseModelBase
from code.core.key_point import KeyPointModel, NullLoss
from code.core.structure import construct_studies

# Relative to the current working directory, not to this file.
# NOTE(review): appending '../nn_tools/' exposes modules *inside* nn_tools;
# importing the package `nn_tools` itself needs its parent on sys.path — confirm.
sys.path.append('../nn_tools/')
from nn_tools import torch_utils  # noqa: E402  (must follow the sys.path tweak)

if __name__ == '__main__':
    # Record the wall-clock time at which training starts.
    start_time = time.time()

    # Load training and validation studies, annotations and counters.
    # multiprocessing=True is safe here because we are under the __main__ guard.
    train_studies, train_annotation, train_counter = construct_studies(
        'data/lumbar_train150/', 'data/lumbar_train150_annotation.json', multiprocessing=True)
    valid_studies, valid_annotation, valid_counter = construct_studies(
        'data/lumbar_train51/', 'data/lumbar_train51_annotation.json', multiprocessing=True)

    # Map (study_uid, series_uid, instance_uid) -> image of each training
    # study's middle T2 sagittal frame.
    train_images = {}
    for study_uid, study in train_studies.items():
        frame = study.t2_sagittal_middle_frame
        train_images[(study_uid, frame.series_uid, frame.instance_uid)] = frame.image

    # ResNet-50 FPN backbone; the second positional argument is the legacy
    # `pretrained` flag. NOTE(review): newer torchvision releases accept this
    # via a deprecation shim — consider the `weights=` keyword when upgrading.
    backbone = resnet_fpn_backbone('resnet50', True)
    kp_model = KeyPointModel(backbone)