Ejemplo n.º 1
0
                               max_translation=0.05,
                               scale_range=(0.9, 1.1),
                               max_angel=10)
    kp_model = KeyPointModelV2(backbone,
                               pixel_mean=0.5,
                               pixel_std=1,
                               loss=KeyPointBCELossV2(lamb=1),
                               spinal_model=spinal_model,
                               cascade_loss=CascadeLossV2(1),
                               loss_scaler=100,
                               num_cascades=3)
    kp_model.load_state_dict(torch.load('../models/pretrained101.kp_model'),
                             strict=False)

    dis_model = DiseaseModelBase(kp_model, sagittal_size=(512, 512))
    dis_model.cuda(0)
    print(dis_model)

    # 设定训练参数
    train_dataloader = DisDataLoader(train_studies,
                                     train_annotation,
                                     batch_size=8,
                                     num_workers=3,
                                     num_rep=20,
                                     prob_rotate=1,
                                     max_angel=180,
                                     sagittal_size=dis_model.sagittal_size,
                                     transverse_size=dis_model.sagittal_size,
                                     k_nearest=0,
                                     max_dist=6,
                                     sagittal_shift=1,
    kp_model = KeyPointModelV2(backbone,
                               pixel_mean=0.5,
                               pixel_std=1,
                               loss=KeyPointBCELossV2(lamb=1),
                               spinal_model=spinal_model,
                               cascade_loss=CascadeLossV2(1),
                               loss_scaler=100,
                               num_cascades=3)
    # 加载第一次训练好的模型
    kp_model.load_state_dict(torch.load('../models/pretrained_34.kp_model'),
                             strict=False)

    # 建立疾病分类模型
    dis_model = DiseaseModelBase(kp_model, sagittal_size=(512, 512))
    # 放置到GPU
    dis_model.cuda()
    # 打印模型参数
    print(dis_model)

    # 设定训练参数
    train_dataloader = DisDataLoader(train_studies,
                                     train_annotation,
                                     batch_size=8,
                                     num_workers=0,
                                     num_rep=20,
                                     prob_rotate=1,
                                     max_angel=180,
                                     sagittal_size=dis_model.sagittal_size,
                                     transverse_size=dis_model.sagittal_size,
                                     k_nearest=0,
                                     max_dist=6,
Ejemplo n.º 3
0
    # Collect one image per training study, the T2 sagittal middle frame,
    # keyed by (study_uid, series_uid, instance_uid). This dict is passed to
    # SpinalModel below together with the training annotations.
    train_images = {}
    for study_uid, study in train_studies.items():
        frame = study.t2_sagittal_middle_frame
        train_images[(study_uid, frame.series_uid, frame.instance_uid)] = frame.image

    # ResNet-34 FPN backbone. The second positional argument is presumably the
    # `pretrained` flag of torchvision's resnet_fpn_backbone; confirm it
    # against the installed torchvision version.
    backbone = resnet_fpn_backbone('resnet34', True)
    # Spinal template model built from the training images/annotations above.
    # NOTE(review): the meanings of the candidate/template/augmentation
    # arguments are defined by SpinalModel (not visible here).
    spinal_model = SpinalModel(train_images, train_annotation,
                               num_candidates=128, num_selected_templates=8,
                               max_translation=0.05, scale_range=(0.9, 1.1), max_angel=10)
    # Key-point model on top of the backbone, with a BCE key-point loss and a
    # cascade loss over 3 cascades.
    kp_model = KeyPointModelV2(backbone, pixel_mean=0.5, pixel_std=1,
                               loss=KeyPointBCELossV2(lamb=1), spinal_model=spinal_model,
                               cascade_loss=CascadeLossV2(1), loss_scaler=100, num_cascades=3)
    # Warm-start from previously trained key-point weights. strict=False makes
    # load_state_dict tolerate missing/unexpected keys, so a mismatch in the
    # checkpoint is NOT reported as an error here.
    # NOTE(review): torch.load without map_location restores tensors onto the
    # device they were saved from; confirm that device exists on this machine.
    kp_model.load_state_dict(torch.load('../models/pretrained_34.kp_model'), strict=False)

    # Disease-classification model wrapping the key-point model, 512x512 sagittal input.
    dis_model = DiseaseModelBase(kp_model, sagittal_size=(512, 512))
    # NOTE(review): hard-coded CUDA device index 1; this requires at least two
    # visible GPUs. Consider making the device configurable.
    dis_model.cuda(1)
    # Print the model (its module structure) for inspection.
    print(dis_model)

    # Set up the training data loader.
    # NOTE(review): transverse_size is given dis_model.sagittal_size; confirm
    # that reusing the sagittal size for transverse slices is intentional.
    train_dataloader = DisDataLoader(
        train_studies, train_annotation, batch_size=8, num_workers=5, num_rep=20, prob_rotate=1, max_angel=180,
        sagittal_size=dis_model.sagittal_size, transverse_size=dis_model.sagittal_size, k_nearest=0, max_dist=6,
        sagittal_shift=1, pin_memory=False
    )

    # Validation evaluator on valid_studies, scored with 'key point recall'.
    # NOTE(review): the annotation file name says "train51" although it is used
    # for validation; confirm this is the intended file.
    valid_evaluator = Evaluator(
        dis_model, valid_studies, '../data/lumbar_train51_annotation.json', num_rep=20, max_dist=6,
        metric='key point recall'
    )

    # NOTE(review): len() of the loader is presumably the number of batches per
    # epoch (as with torch DataLoader), so this is really "steps per epoch"
    # despite the name. Confirm against DisDataLoader.__len__.
    step_per_batch = len(train_dataloader)