def train():
    args = util.get_args_shape_training()

    # Params / Config
    learning_rate = 1e-3
    momentums_cae = (0.99, 0.999)
    criterion = metrics.BatchDiceLoss([1.0])  # nn.BCELoss()
    path_training_metrics = args.continuetraining  # --continuetraining /share/data_zoe1/lucas/Linda_Segmentations/tmp/tmp_shape_f3.json
    path_saved_model = args.caepath
    channels_cae = args.channelscae
    n_globals = args.globals  # type(core/penu), tO_to_tA, NHISS, sex, age
    resample_size = int(args.xyoriginal * args.xyresample)
    pad = args.padding
    pad_value = 0
    leakage = 0.01
    cuda = True

    # CAE model
    enc = Enc3DCtp(size_input_xy=resample_size, size_input_z=args.zsize,
                   channels=channels_cae, n_ch_global=n_globals, leakage=leakage, padding=pad)
    dec = Dec3D(size_input_xy=resample_size, size_input_z=args.zsize,
                channels=channels_cae, n_ch_global=n_globals, leakage=leakage)
    cae = Cae3DCtp(enc, dec)
    if cuda:
        cae = cae.cuda()

    # Model params
    params = [p for p in cae.parameters() if p.requires_grad]
    print('# optimizing params', sum([p.nelement() * p.requires_grad for p in params]),
          '/ total: cae', sum([p.nelement() for p in cae.parameters()]))

    # Optimizer with scheduler
    optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=1e-5, betas=momentums_cae)
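    # If args.lrsteps is given, MultiStepLR decays the learning rate at those epoch milestones
    # (by its default factor gamma=0.1); otherwise the learning rate stays constant.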
    if args.lrsteps:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lrsteps)
    else:
        scheduler = None

    # Data
    common_transform = [data.ResamplePlaneXY(args.xyresample),
                        data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid),
                        data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value)]
    train_transform = common_transform + [data.ElasticDeform(), data.ToTensor()]
    valid_transform = common_transform + [data.ToTensor()]
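    # Elastic deformation augments the training set only; resampling, case-ID-based hemispheric
    # flipping and padding are shared by training and validation.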
    modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled']
    labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled',
              '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled']
    ds_train, ds_valid = data.get_stroke_shape_training_data(modalities, labels, train_transform, valid_transform,
                                                             args.fold, args.validsetsize, batchsize=args.batchsize)
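    # The fold's cases are presumably split into training and validation subsets here, with
    # args.validsetsize controlling the size of the validation split.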
    print('Size training set:', len(ds_train.sampler.indices), 'samples | Size validation set:', len(ds_valid.sampler.indices),
          'samples | Capacity batch:', args.batchsize, 'samples')
    print('# training batches:', len(ds_train), '| # validation batches:', len(ds_valid))

    # Training
    learner = CaeReconstructionLearner(ds_train, ds_valid, cae, path_saved_model, optimizer, scheduler,
                                       path_outputs_base=args.outbasepath)
    learner.run_training()
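

# Minimal entry-point sketch (illustration only; the original repository presumably wires
# train()/test() up through its own run scripts):
# if __name__ == '__main__':
#     train()
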
def test(args):
    # Params / Config
    modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled']
    labels = [
        '_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled',
        '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'
    ]
    normalization_hours_penumbra = args.normalize
    pad = args.padding
    pad_value = 0

    for idx in range(len(args.path)):
        # Data
        transform = [
            data.ResamplePlaneXY(args.xyresample),
            data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value),
            data.ToTensor()
        ]
        ds_test = data.get_testdata(modalities=modalities,
                                    labels=labels,
                                    transform=transform,
                                    indices=args.fold[idx])

        print('Size test set:', len(ds_test.sampler.indices), '| # batches:',
              len(ds_test))

        # Single case evaluation
        tester = CaeReconstructionTester(ds_test, args.path[idx],
                                         args.outbasepath,
                                         normalization_hours_penumbra)
        tester.run_inference()
def test():
    args = util.get_args_shape_testing()

    assert len(args.fold) == len(args.path), \
        'You must provide as many --fold arguments as caepath model arguments in the exact same order!'

    # Params / Config
    modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled']
    labels = [
        '_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled',
        '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'
    ]
    normalization_hours_penumbra = args.normalize
    steps = range(6)  # fixed steps for tAdmission-->tReca: 0-5 hrs
    pad = args.padding
    pad_value = 0

    # Data
    transform = [
        data.ResamplePlaneXY(args.xyresample),
        data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value),
        data.ToTensor()
    ]

    # Fold-wise evaluation: run each fold's model on its corresponding test indices (model paths and fold indices are given as arguments in matching order):
    for i, path in enumerate(args.path):
        print('Model ' + path + ' of fold ' + str(i + 1) + '/' +
              str(len(args.fold)) + ' with indices: ' + str(args.fold[i]))
        ds_test = data.get_testdata(modalities=modalities,
                                    labels=labels,
                                    transform=transform,
                                    indices=args.fold[i])
        print('Size test set:', len(ds_test.sampler.indices), '| # batches:',
              len(ds_test))
        # Single case evaluation for all cases in fold
        tester = CaeReconstructionTesterCurve(ds_test, path, args.outbasepath,
                                              normalization_hours_penumbra,
                                              steps)
        tester.run_inference()
def test(args):

    # Params / Config
    modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled']
    labels = ['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled']
    path_saved_model = args.unetpath
    pad = args.padding
    pad_value = 0

    # Data
    # The Unet was trained on patches, but being fully convolutional it can be applied to the full image (hence the patch transform is omitted)
    transform = [data.ResamplePlaneXY(args.xyresample),
                 data.PadImages(pad[0], pad[1], pad[2], pad_value=pad_value),
                 data.ToTensor()]
    ds_test = data.get_testdata(modalities=modalities, labels=labels, transform=transform, indices=args.fold)

    print('Size test set:', len(ds_test.sampler.indices), '| # batches:', len(ds_test))

    # Single case evaluation
    tester = UnetSegmentationTester(ds_test, path_saved_model, args.outbasepath, None)
    tester.run_inference()
def train(args):
    # Params / Config
    learning_rate = 1e-3
    momentums_cae = (0.9, 0.999)
    weight_decay = 1e-5
    criterion = metrics.BatchDiceLoss([1.0])  # nn.BCELoss()
    channels_cae = args.channelscae
    n_globals = args.globals  # type(core/penu), tO_to_tA, NHISS, sex, age
    resample_size = int(args.xyoriginal * args.xyresample)
    alpha = 1.0
    cuda = True

    # CAE model
    cae = torch.load(args.caepath)
    cae.freeze(True)
    enc = Enc3DStep(size_input_xy=resample_size,
                    size_input_z=args.zsize,
                    channels=channels_cae,
                    n_ch_global=n_globals,
                    alpha=alpha)
    enc.encoder = cae.enc.encoder  # enc.step will be trained from scratch for given shape representations
    dec = cae.dec
    cae = Cae3D(enc, dec)
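    # Presumably only the newly constructed step layers of Enc3DStep are trainable: the reused
    # encoder and the pretrained decoder were frozen via cae.freeze(True) above.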

    if cuda:
        cae = cae.cuda()

    # Model params
    params = [p for p in cae.parameters() if p.requires_grad]
    print('# optimizing params',
          sum([p.nelement() * p.requires_grad for p in params]),
          '/ total: cae', sum([p.nelement() for p in cae.parameters()]))

    # Optimizer with scheduler
    optimizer = torch.optim.Adam(params,
                                 lr=learning_rate,
                                 weight_decay=weight_decay,
                                 betas=momentums_cae)
    if args.lrsteps:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lrsteps)
    else:
        scheduler = None

    # Data
    common_transform = [data.ResamplePlaneXY(args.xyresample)]  # before: FixedToCaseId(split_id=args.hemisflipid)
    train_transform = common_transform + [
        data.HemisphericFlip(),
        data.ElasticDeform(),
        data.ToTensor()
    ]
    valid_transform = common_transform + [data.ToTensor()]

    modalities = ['_CBV_reg1_downsampled', '_TTD_reg1_downsampled']  # dummy data only needed for visualization
    labels = [
        '_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled',
        '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'
    ]

    ds_train, ds_valid = data.get_stroke_shape_training_data(
        modalities,
        labels,
        train_transform,
        valid_transform,
        args.fold,
        args.validsetsize,
        batchsize=args.batchsize)
    print('Size training set:',
          len(ds_train.sampler.indices), 'samples | Size validation set:',
          len(ds_valid.sampler.indices), 'samples | Capacity batch:',
          args.batchsize, 'samples')
    print('# training batches:', len(ds_train), '| # validation batches:',
          len(ds_valid))

    # Training
    learner = CaeStepLearner(ds_train,
                             ds_valid,
                             cae,
                             optimizer,
                             scheduler,
                             n_epochs=args.epochs,
                             path_previous_base=args.inbasepath,
                             path_outputs_base=args.outbasepath,
                             criterion=criterion)
    learner.run_training()
def train():
    args = util.get_args_unet_training()

    # Params / Config
    batchsize = 6  # 17 training, 6 validation
    learning_rate = 1e-3
    momentums_cae = (0.99, 0.999)
    criterion = metrics.BatchDiceLoss([1.0])  # nn.BCELoss()
    path_saved_model = args.unetpath
    channels = args.channels
    pad = args.padding
    cuda = True

    # Unet model
    unet = Unet3D(channels)
    if cuda:
        unet = unet.cuda()

    # Model params
    params = [p for p in unet.parameters() if p.requires_grad]
    print('# optimizing params',
          sum([p.nelement() * p.requires_grad for p in params]),
          '/ total: unet', sum([p.nelement() for p in unet.parameters()]))

    # Optimizer with scheduler
    optimizer = torch.optim.Adam(params,
                                 lr=learning_rate,
                                 weight_decay=1e-5,
                                 betas=momentums_cae)
    if args.lrsteps:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lrsteps)
    else:
        scheduler = None

    # Data
    train_transform = [
        data.ResamplePlaneXY(args.xyresample),
        data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid),
        data.PadImages(pad[0], pad[1], pad[2], pad_value=0),
        data.RandomPatch(104, 104, 68, pad[0], pad[1], pad[2]),
        data.ToTensor()
    ]
    valid_transform = [
        data.ResamplePlaneXY(args.xyresample),
        data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid),
        data.PadImages(pad[0], pad[1], pad[2], pad_value=0),
        data.RandomPatch(104, 104, 68, pad[0], pad[1], pad[2]),
        data.ToTensor()
    ]
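    # Both pipelines crop random 104x104x68 patches (the padding margins are passed on to
    # RandomPatch). Unlike the shape-training snippets above, no modalities/labels are passed
    # to get_stroke_shape_training_data below, which presumably falls back to defaults here.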
    ds_train, ds_valid = data.get_stroke_shape_training_data(
        train_transform,
        valid_transform,
        args.fold,
        args.validsetsize,
        batchsize=batchsize)
    print('Size training set:',
          len(ds_train.sampler.indices), 'samples | Size validation set:',
          len(ds_valid.sampler.indices), 'samples | Capacity batch:',
          batchsize, 'samples')
    print('# training batches:', len(ds_train), '| # validation batches:',
          len(ds_valid))

    # Training
    learner = UnetSegmentationLearner(ds_train,
                                      ds_valid,
                                      unet,
                                      path_saved_model,
                                      optimizer,
                                      scheduler,
                                      criterion,
                                      path_previous_base=args.inbasepath,
                                      path_outputs_base=args.outbasepath)
    learner.run_training()
def infer():
    args = util.get_args_sdm()

    print('Evaluate validation set', args.fold)

    # Params / Config
    normalization_hours_penumbra = 10
    #channels_unet = args.channels  TODO Unet live segmentation
    #pad = args.padding  TODO Unet live segmentation

    transform = [data.ResamplePlaneXY(args.xyresample),
                 data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid),
                 #data.PadImages(pad[0], pad[1], pad[2], pad_value=0),  TODO Unet live segmentation
                 data.ToTensor()]

    ds_test = data.get_testdata(modalities=['_unet_core', '_unet_penu'],  # modalities=['_CBV_reg1_downsampled', '_TTD_reg1_downsampled'],  TODO Unet live segmentation
                                labels=['_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled',
                                        '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'],
                                transform=transform,
                                indices=args.fold)

    # Unet
    #unet = None  TODO Unet live segmentation
    #if not args.groundtruth:  TODO Unet live segmentation
    #    unet = Unet3D(channels=channels_unet)  TODO Unet live segmentation
    #    unet.load_state_dict(torch.load(args.unet))  TODO Unet live segmentation
    #    unet.train(False)  # fixate regularization for forward-only!  TODO Unet live segmentation

    for sample in ds_test:
        case_id = sample[data.KEY_CASE_ID].cpu().numpy()[0]

        nifph = nib.load('/share/data_zoe1/lucas/Linda_Segmentations/' + str(case_id) + '/train' + str(case_id) +
                         '_CBVmap_reg1_downsampled.nii.gz').affine

        to_to_ta, normalization = get_normalized_time(sample, normalization_hours_penumbra)

        lesion = Variable(sample[data.KEY_LABELS][:, 2, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
        if args.groundtruth:
            core = Variable(sample[data.KEY_LABELS][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
            penu = Variable(sample[data.KEY_LABELS][:, 1, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
        else:
            #dto = UnetDtoInit.init_dto(Variable(sample[data.KEY_IMAGES]), None, None)  TODO Unet live segmentation
            #dto = unet(dto)  TODO Unet live segmentation
            #core = dto.outputs.core  TODO Unet live segmentation
            #penu = dto.outputs.penu,  TODO Unet live segmentation
            core = Variable(sample[data.KEY_IMAGES][:, 0, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))
            penu = Variable(sample[data.KEY_IMAGES][:, 1, :, :, :].unsqueeze(data.DIM_CHANNEL_TORCH3D_5))

        ta_to_tr = sample[data.KEY_GLOBAL][:, 1, :, :, :].squeeze().unsqueeze(data.DIM_CHANNEL_TORCH3D_5)
        time_to_treatment = Variable(ta_to_tr.type(torch.FloatTensor) / normalization)
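        # Global feature index 1 presumably holds the admission-to-recanalization time (tA->tR);
        # it is rescaled by the penumbra normalization window computed above.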

        del to_to_ta
        del normalization
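
        # Interpolate between core and penumbra signed distance maps (SDMs) at the normalized
        # time-to-treatment. With the sign convention used below, the core mask is recovered as
        # recon_core < 0, the interpolated lesion and the penumbra as recon_intp/recon_penu > 0.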

        recon_core, recon_intp, recon_penu, latent_core, latent_intp, latent_penu = \
            sdm_interpolate_numpy(core.data.cpu().numpy(), penu.data.cpu().numpy(), threshold=0.5,
                                  interpolation=time_to_treatment.data.cpu().numpy().squeeze(), zoom=12,
                                  resample=args.downsample)

        print(int(sample[data.KEY_CASE_ID]), 'TO-->TR', float(time_to_treatment))

        if args.visualinspection:
            fig, axes = plt.subplots(3, 4)

            axes[0, 0].imshow(core.cpu().data.numpy()[0, 0, 16, :, :], cmap='gray', vmin=0, vmax=1)
            axes[1, 0].imshow(lesion.cpu().data.numpy()[0, 0, 16, :, :], cmap='gray', vmin=0, vmax=1)
            axes[2, 0].imshow(penu.cpu().data.numpy()[0, 0, 16, :, :], cmap='gray', vmin=0, vmax=1)

            axes[0, 1].imshow(latent_core[16, :, :], cmap='gray')
            axes[1, 1].imshow(latent_intp[16, :, :], cmap='gray')
            axes[2, 1].imshow(latent_penu[16, :, :], cmap='gray')

            axes[0, 2].imshow(recon_core[16, :, :], cmap='gray')
            axes[1, 2].imshow(recon_intp[16, :, :], cmap='gray')
            axes[2, 2].imshow(recon_penu[16, :, :], cmap='gray')

            axes[0, 3].imshow(recon_core[16, :, :] < 0, cmap='gray', vmin=0, vmax=1)
            axes[1, 3].imshow(recon_intp[16, :, :] > 0, cmap='gray', vmin=0, vmax=1)
            axes[2, 3].imshow(recon_penu[16, :, :] > 0, cmap='gray', vmin=0, vmax=1)
            plt.show()

        # Compare the thresholded reconstructions against ground truth (DC = Dice coefficient,
        # HD = Hausdorff distance, ASSD = average symmetric surface distance, see print below).
        results = metrics.binary_measures_numpy((recon_intp > 0).astype(float),
                                                lesion.cpu().data.numpy()[0, 0, :, :, :], binary_threshold=0.5)

        c_res = metrics.binary_measures_numpy((recon_core < 0).astype(float),
                                               core.cpu().data.numpy()[0, 0, :, :, :], binary_threshold=0.5)

        p_res = metrics.binary_measures_numpy((recon_penu > 0).astype(float),
                                               penu.cpu().data.numpy()[0, 0, :, :, :], binary_threshold=0.5)

        with open('/data_zoe1/lucas/Linda_Segmentations/tmp/sdm_results.txt', 'a') as f:
            print('Evaluate case: {} - DC:{:.3}, HD:{:.3}, ASSD:{:.3}, Core recon DC:{:.3}, Penu recon DC:{:.3}'.format(case_id,
                results.dc, results.hd,  results.assd, c_res.dc, p_res.dc), file=f)
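
        # Save results as NIfTI volumes: each array's axes are reversed via transpose and the
        # volume is upsampled in-plane by a factor of 2 (zoom=(2, 2, 1)) before being written
        # with the case's original affine (nifph).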

        zoomed = ndi.interpolation.zoom(recon_intp.transpose((2, 1, 0)), zoom=(2, 2, 1))
        nib.save(nib.Nifti1Image((zoomed > 0).astype(np.float32), nifph), args.outbasepath + '_' + str(case_id) + '_lesion.nii.gz')
        del zoomed

        zoomed = ndi.interpolation.zoom(lesion.cpu().data.numpy().astype(np.int8).transpose((4, 3, 2, 1, 0))[:, :, :, 0, 0], zoom=(2, 2, 1))
        nib.save(nib.Nifti1Image(zoomed, nifph), args.outbasepath + '_' + str(case_id) + '_fuctgt.nii.gz')
        del zoomed

        zoomed = ndi.interpolation.zoom(recon_core.transpose((2, 1, 0)), zoom=(2, 2, 1))
        nib.save(nib.Nifti1Image((zoomed < 0).astype(np.float32), nifph), args.outbasepath + '_' + str(case_id) + '_core.nii.gz')
        del zoomed

        zoomed = ndi.interpolation.zoom(recon_penu.transpose((2, 1, 0)), zoom=(2, 2, 1))
        nib.save(nib.Nifti1Image((zoomed > 0).astype(np.float32), nifph), args.outbasepath + '_' + str(case_id) + '_penu.nii.gz')

        del nifph

        del sample
def train(args):
    # Params / Config
    learning_rate = 1e-3
    momentums_cae = (0.9, 0.999)
    weight_decay = 1e-5
    criterion = metrics.BatchDiceLoss([1.0])  # nn.BCELoss()
    resample_size = int(args.xyoriginal * args.xyresample)
    n_globals = args.globals  # type(core/penu), tO_to_tA, NHISS, sex, age
    channels_enc = args.channelsenc
    alpha = 1.0
    cuda = True

    # TODO assert initbycae XOR channels_enc

    # CAE model
    path_saved_model = args.caepath
    cae = torch.load(path_saved_model)
    cae.freeze(True)
    if args.initbycae:
        enc = torch.load(path_saved_model).enc
    else:
        enc = Enc3D(size_input_xy=resample_size,
                    size_input_z=args.zsize,
                    channels=channels_enc,
                    n_ch_global=n_globals,
                    alpha=alpha)
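    # The prediction encoder is either initialized from the pretrained CAE (args.initbycae) or
    # built from scratch with args.channelsenc; the loaded CAE stays frozen (cae.freeze(True))
    # and presumably provides the fixed shape space the new encoder learns to predict into.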

    if cuda:
        cae = cae.cuda()
        enc = enc.cuda()

    # Model params
    params = [p for p in enc.parameters() if p.requires_grad]
    print(
        '# optimizing params',
        sum([p.nelement() * p.requires_grad for p in params]),
        '/ total new enc + old cae',
        sum([p.nelement() for p in enc.parameters()] +
            [p.nelement() for p in cae.parameters()]))

    # Optimizer with scheduler
    optimizer = torch.optim.Adam(params,
                                 lr=learning_rate,
                                 weight_decay=weight_decay,
                                 betas=momentums_cae)
    if args.lrsteps:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lrsteps)
    else:
        scheduler = None

    # Data
    common_transform = [
        data.ResamplePlaneXY(args.xyresample),
        data.HemisphericFlipFixedToCaseId(split_id=args.hemisflipid)
    ]
    train_transform = common_transform + [
        data.ElasticDeform(apply_to_images=True),
        data.ToTensor()
    ]
    valid_transform = common_transform + [data.ToTensor()]
    modalities = ['_unet_core', '_unet_penu']
    labels = [
        '_CBVmap_subset_reg1_downsampled', '_TTDmap_subset_reg1_downsampled',
        '_FUCT_MAP_T_Samplespace_subset_reg1_downsampled'
    ]
    ds_train, ds_valid = data.get_stroke_prediction_training_data(
        modalities,
        labels,
        train_transform,
        valid_transform,
        args.fold,
        args.validsetsize,
        batchsize=args.batchsize)
    print('Size training set:',
          len(ds_train.sampler.indices), 'samples | Size validation set:',
          len(ds_valid.sampler.indices), 'samples | Capacity batch:',
          args.batchsize, 'samples')
    print('# training batches:', len(ds_train), '| # validation batches:',
          len(ds_valid))

    # Training
    learner = CaePredictionLearner(ds_train,
                                   ds_valid,
                                   cae,
                                   enc,
                                   optimizer,
                                   scheduler,
                                   n_epochs=args.epochs,
                                   path_previous_base=args.inbasepath,
                                   path_outputs_base=args.outbasepath,
                                   criterion=criterion)
    learner.run_training()