import sys
import time

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

from code.core.disease import DisDataLoader, Evaluator, DiseaseModelV3, DisLoss
from code.core.key_point import SpinalModel, KeyPointModelV2, KeyPointBCELossV2, NullLoss, CascadeLossV2
from code.core.structure import construct_studies

sys.path.append('core/nn_tools/')
from core.nn_tools import torch_utils

if __name__ == '__main__':
    # Record the wall-clock time (seconds since the epoch) when training starts.
    start_time = time.time()

    # Load the training studies, their annotations and the counter object
    # returned by construct_studies.
    # NOTE: multiprocessing is explicitly disabled here, so loading runs in a
    # single process.
    train_studies, train_annotation, train_counter = construct_studies(
        '../data/lumbar_train150',
        '../data/lumbar_train150_annotation.json',
        multiprocessing=False)

    # Load the validation studies from the separate "lumbar_train51" set in the
    # same way, also without multiprocessing.
    valid_studies, valid_annotation, valid_counter = construct_studies(
        '../data/lumbar_train51/',
        '../data/lumbar_train51_annotation.json',
        multiprocessing=False)

    # Collect the image of each training study's T2 sagittal middle frame,
    # keyed by (study_uid, series_uid, instance_uid).
    train_images = {}
    for study_uid, study in train_studies.items():
        frame = study.t2_sagittal_middle_frame
        train_images[(study_uid, frame.series_uid,
                      frame.instance_uid)] = frame.image
# Example no. 2
import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

from code.core.disease.data_loader import DisDataLoader
from code.core.disease.evaluation import Evaluator
from code.core.disease.model import DiseaseModelBase
from code.core.key_point import KeyPointModel, NullLoss
from code.core.structure import construct_studies

sys.path.append('../nn_tools/')
from nn_tools import torch_utils

if __name__ == '__main__':
    start_time = time.time()
    train_studies, train_annotation, train_counter = construct_studies(
        'data/lumbar_train150/',
        'data/lumbar_train150_annotation.json',
        multiprocessing=True)
    valid_studies, valid_annotation, valid_counter = construct_studies(
        'data/lumbar_train51/',
        'data/lumbar_train51_annotation.json',
        multiprocessing=True)

    # 设定模型参数
    train_images = {}
    for study_uid, study in train_studies.items():
        frame = study.t2_sagittal_middle_frame
        train_images[(study_uid, frame.series_uid,
                      frame.instance_uid)] = frame.image

    backbone = resnet_fpn_backbone('resnet50', True)
    kp_model = KeyPointModel(backbone)