Example #1
def apply_affine_transform(self, patch, rotation_params, translation_params):
    """Apply an affine transform, built from the given rotation and translation parameters, to a patch."""
    transform = self.get_transform(rotation_params, translation_params)
    # Reshape the transform's flat rotation matrix and translation vector
    A = torch.Tensor(transform.GetMatrix()).view(-1, 3, 3)
    b = torch.Tensor(transform.GetTranslation()).view(-1, 3, 1)
    # Concatenate into an [N, 3, 4] theta, scaling the translation to voxel units
    theta = torch.cat((A, b * self.voxelsize), dim=2)
    transformed_patch = affine_transform(patch, theta)
    return transformed_patch
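The [N, 3, 4] theta built above is the affine-matrix layout that PyTorch's grid-based resampling (torch.nn.functional.affine_grid) expects, and which the project's affine_transform helper presumably wraps. A minimal standalone sketch of the same construction, where the identity rotation, the translation, and the voxel size are made-up values for illustration:

import torch

# Hypothetical stand-ins for transform.GetMatrix() and transform.GetTranslation()
rotation = torch.eye(3).flatten().tolist()    # identity rotation, row-major, 9 values
translation = [0.5, -0.2, 0.1]                # made-up translation vector
voxelsize = 7e-4                              # assumed voxel size, as in the later examples

A = torch.Tensor(rotation).view(-1, 3, 3)     # [1, 3, 3] rotation block
b = torch.Tensor(translation).view(-1, 3, 1)  # [1, 3, 1] translation column
theta = torch.cat((A, b * voxelsize), dim=2)  # [1, 3, 4] affine parameters
print(theta.shape)                            # torch.Size([1, 3, 4])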
Example #2
def validate(fixed_patches, moving_patches, epoch, model, criterion, weight,
             device):
    """Validating the model using part of the dataset
        Args:
            fixed_patches (Tensor): Tensor holding the fixed_patches ([num_patches, 1, patch_size, patch_size, patch_size])
            moving_patches (Tensor): Tensor holding the moving patches ([num_patches, 1, patch_size, patch_size, patch_size])
            epoch (int): current epoch
            model (nn.Module): Network model
            criterion (nn.Module): Loss-function
            weight (float): float number to weight the regularizer in the loss function
        Returns:
            array of validation losses over each batch
    """

    validation_set = CreateDataset(fixed_patches, moving_patches)
    validation_loader = DataLoader(validation_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=0,
                                   pin_memory=True,
                                   drop_last=False)

    validation_loss = torch.zeros(len(validation_loader), device=device)

    for batch_idx, (fixed_batch, moving_batch) in enumerate(validation_loader):

        fixed_batch = fixed_batch.to(device)
        moving_batch = moving_batch.to(device)

        predicted_theta = model(fixed_batch, moving_batch)
        predicted_deform = affine_transform(moving_batch, predicted_theta)

        loss, cross_corr = criterion(fixed_batch,
                                     predicted_deform,
                                     predicted_theta,
                                     weight,
                                     reduction='mean')
        validation_loss[batch_idx] = cross_corr.item()

        printer = progress_printer((batch_idx + 1) / len(validation_loader))
        print(printer + ' Validating epoch {:2}/{} (steps: {})'.format(
            epoch + 1, args.epochs, len(validation_loader)),
              end='\r')

    print('\n')

    return validation_loss
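Note that validate() neither switches the model to eval mode nor disables gradient tracking itself; callers are presumably expected to wrap it in model.eval() and torch.no_grad(), as in the epoch-loop sketch after Example #6 below.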
Example #3
def main():
    global args, user_config

    args = parse()
    user_config = UserConfigParser()

    # Read dataset information from file
    dataset = GetDatasetInformation(os.path.join(user_config.DATA_ROOT, args.frame), args.aft, mode='prediction')

    # Fetch the relevant file and volume lists from the dataset
    fixed_image = dataset.fix_files
    moving_image = dataset.mov_files
    fix_vol = dataset.fix_vols
    mov_vol = dataset.mov_vols

    dims = dataset.get_biggest_dimensions()

    voxelsize = 7.000003e-4

    # Get correct paths
    data_files = os.path.join(user_config.DATA_ROOT, 'patient_data_proc_{}/'.format(args.aft))
    theta_proc_path = os.path.join(user_config.PROJECT_ROOT, user_config.PROCRUSTES, 'results', args.glob)

    # Get ultrasound volume data
    vol_data = LoadHDF5File(data_files,
                            fixed_image[args.PSN - 1], moving_image[args.PSN - 1],
                            fix_vol[args.PSN - 1], mov_vol[args.PSN - 1], dims=None)

    vol_data.normalize()

    # Add batch dimension
    fixed_volume = vol_data.fix_data.unsqueeze(0)
    moving_volume = vol_data.mov_data.unsqueeze(0)

    # Read global theta from file
    global_theta = []
    with open(theta_proc_path, 'r') as readTheta:
        for i, theta in enumerate(readTheta.read().split()):
            # Skip tokens that are exactly '0' or '1' (the constant entries of the homogeneous matrix)
            if theta != '1' and theta != '0':
                if i == 3 or i == 7 or i == 11:
                    # Append translation values, scaled to voxel units
                    global_theta.append(float(theta) * voxelsize * 10)
                else:
                    # Append rotation values
                    global_theta.append(float(theta))

    global_theta = torch.Tensor(global_theta)
    global_theta = global_theta.view(-1, 3, 4)  # Reshape theta to [N, 3, 4] for the affine transform
    print('Global theta:')
    print(global_theta)
    print('\n')

    warped_volume = affine_transform(moving_volume, global_theta)

    # Compute loss pre- and post-alignment
    pre_loss, pre_mask = unmasked_normalized_cross_correlation(fixed_volume, moving_volume, reduction=None)
    post_loss, post_mask = masked_normalized_cross_correlation(fixed_volume, warped_volume, reduction=None)

    print('\n')
    print('Post-alignment values')
    print('*' * 100)
    print('Prediction set' + ' | ' + 'NCC before warping' + ' | ' + 'NCC after warping' + ' | ' +
          'Improvement' + ' | ' + 'Percentwise imp.')
    print('{:<8}{:>20}{:>20}{:>18}{:>13}%'.format(args.PSN,
                                                  round(pre_loss.item(), 4),
                                                  round(post_loss.item(), 4),
                                                  round((post_loss.item() - pre_loss.item()), 4),
                                                  round(100 - ((pre_loss.item() / post_loss.item()) * 100), 2)))

    plot_volumes(fixed_volume, moving_volume, warped_volume, pre_mask, post_mask)
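The parsing loop above recovers a [1, 3, 4] theta from a 4x4 homogeneous matrix stored as whitespace-separated text. A more direct sketch of the same extraction, assuming the matrix is already available as a tensor (the matrix values below are made up):

import torch

# Made-up 4x4 homogeneous transform: rotation block plus translation column
H = torch.tensor([[0.99, -0.05,  0.02,  1.3],
                  [0.05,  0.99,  0.01, -0.7],
                  [-0.02, -0.01, 0.99,  0.4],
                  [0.00,  0.00,  0.00,  1.0]])

voxelsize = 7.000003e-4
theta = H[:3, :].clone()       # drop the constant bottom row -> [3, 4]
theta[:, 3] *= voxelsize * 10  # scale the translation column, as in the loop above
theta = theta.unsqueeze(0)     # [1, 3, 4], ready for affine_transform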
Example #4
def main():
    global args, user_config

    user_config = UserConfigParser()  # Parse main_config.ini
    args = parse()

    #torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # GPU configuration
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices

    voxelsize = 7.0000003e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    uname = platform.uname()
    print('\n')
    print('Initializing prediction with the following configuration:')
    print('\n')
    print("=" * 40, "System Information", "=" * 40)
    print(f"System: {uname.system}")
    print(f"Node Name: {uname.node}")
    print(f"Release: {uname.release}")
    print(f"Version: {uname.version}")
    print(f"Machine: {uname.machine}")
    print(f"Processor: {uname.processor}")
    print("=" * 40, "GPU Information", "=" * 40)
    print(f"CUDA_VISIBLE_DEVICES: {args.cuda_visible_devices}")
    print(f"Device: {device}")
    print("=" * 40, "Parameters", "=" * 40)
    print(f"Model name: {args.model_name}")
    print(f"Batch size: {user_config.batch_size}")
    print(f"Patch size: {user_config.patch_size}")
    print(f"Stride: {user_config.stride}")
    print(f"Filter type: {args.filter_type}")
    print(f"Frame: End-systole") if args.frame == 'end_systole.csv' else print(
        f"Frame: End-diastole")
    print(f"Prediction precision: float32") if apexImportError else print(
        f"Prediction precision: {args.precision}")
    print('\n')

    model_name = os.path.join(user_config.PROJECT_ROOT,
                              user_config.PROJECT_NAME,
                              'output/models/{}.pt'.format(args.model_name))
    data_files = os.path.join(user_config.DATA_ROOT,
                              'patient_data_proc_{}/'.format(args.filter_type))

    posFile = os.path.join(user_config.PROJECT_ROOT, 'procrustes_analysis',
                           'loc_predictions',
                           'loc_prediction_{}.csv'.format(args.model_name))
    thetaFile = os.path.join(user_config.PROJECT_ROOT, 'procrustes_analysis',
                             'theta_predictions',
                             'theta_prediction_{}.csv'.format(args.model_name))

    predictionStorage = FileHandler(posFile, thetaFile)
    predictionStorage.create()

    # Configuration of the model
    model_config = network_config()

    encoder = _Encoder(**model_config['ENCODER_CONFIG'])
    affineRegression = _AffineRegression(**model_config['AFFINE_CONFIG'])
    model = USARNet(encoder, affineRegression).to(device)

    # Load model with existing weights
    print('Loading weights ...')
    loadModel = torch.load(model_name, map_location=device)
    model.load_state_dict(loadModel['model_state_dict'])

    # Decide on FP32 prediction or mixed precision
    if args.precision == 'amp' and not apexImportError:
        model = amp.initialize(model, opt_level='O2')
    elif args.precision == 'amp' and apexImportError:
        print(
            'Error: Apex not found, cannot go ahead with mixed precision prediction. Continuing with full precision.'
        )

    # Generate prediction data
    fixed_patches, moving_patches, loc = generate_prediction_patches(
        DATA_ROOT=user_config.DATA_ROOT,
        data_files=data_files,
        frame=args.frame,
        filter_type=args.filter_type,
        patch_size=user_config.patch_size,
        stride=user_config.stride,
        device=device,
        PSN=args.PSN)

    print('\n')
    print('Number of prediction samples: {}'.format(fixed_patches.shape[0]))
    print('\n')

    # Create dataset of prediction data for use with the DataLoader
    prediction_set = CreateDataset(fixed_patches, moving_patches, loc)
    prediction_loader = DataLoader(prediction_set,
                                   batch_size=user_config.batch_size,
                                   shuffle=False,
                                   num_workers=0,
                                   pin_memory=False,
                                   drop_last=False)

    # Set correct dtype
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    # Create empty tensors for writing position and prediction data
    predicted_theta_tmp = torch.Tensor(1, user_config.batch_size,
                                       12).type(dtype).to(device)
    pos_tmp = torch.Tensor(1, user_config.batch_size, 3).type(dtype).to(device)

    sampleNumber = 1  # Holds the index for writing correctly to the .h5 file
    saveData = SaveHDF5File(user_config.DATA_ROOT)  # Initialize file for saving patches

    # Variables to store loss
    pre_ncc = 0
    post_ncc = 0

    print('Predicting')

    # No grads to be computed
    with torch.no_grad():

        # Eval mode for prediction
        model.eval()

        for batch_idx, (fixed_batch, moving_batch,
                        loc) in enumerate(prediction_loader):

            # Run and time the prediction model (synchronize only when CUDA is
            # available; torch.cuda.synchronize() raises on CPU-only setups)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            model_rt = time.time()
            predicted_theta = model(fixed_batch, moving_batch)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end_rt = time.time()
            print('Model runtime: ', end_rt - model_rt)

            # Transform moving data based on prediction
            warped_batch = affine_transform(moving_batch, predicted_theta)

            # Saves the patches as HDF5 data
            if args.save_data:
                saveData.save_hdf5(fixed_batch=fixed_batch,
                                   moving_batch=moving_batch,
                                   warped_batch=warped_batch,
                                   sampleNumber=sampleNumber)

            # Update sampleNumber, used for saving predictions
            sampleNumber += user_config.batch_size

            # Compute normalized cross correlation values
            preWarpNcc, pre_mask = unmasked_normalized_cross_correlation(
                fixed_batch, moving_batch, reduction=None)
            postWarpNcc, post_mask = masked_normalized_cross_correlation(
                fixed_batch, warped_batch, reduction=None)

            # Plot predictions for each patch
            if args.plot_patchwise_prediction:
                plotPatchwisePrediction(fixed_batch=fixed_batch.cpu(),
                                        moving_batch=moving_batch.cpu(),
                                        predicted_theta=predicted_theta.cpu(),
                                        PROJ_ROOT=user_config.PROJECT_ROOT,
                                        PROJ_NAME=user_config.PROJECT_NAME,
                                        pre_mask=pre_mask.cpu(),
                                        post_mask=post_mask.cpu())

            predicted_theta = predicted_theta.view(-1, 12)
            predicted_theta_tmp = predicted_theta.type(dtype)
            loc_tmp = loc.type(dtype)

            # Store position and prediction of patches
            predictionStorage.write(loc=loc_tmp.cpu().numpy().round(5),
                                    theta=predicted_theta_tmp.cpu().numpy())

            print_patchloss(preWarpNcc, postWarpNcc)

            pre_ncc += torch.sum(preWarpNcc, 0)
            post_ncc += torch.sum(postWarpNcc, 0)

        pre_av = torch.div(pre_ncc, fixed_patches.shape[0]).item()
        post_av = torch.div(post_ncc, fixed_patches.shape[0]).item()

        print('\n')
        print('Averaged values')
        print('*' * 100)
        print('Number of patches' + ' | ' + 'NCC before warping' + ' | ' +
              'NCC after warping' + ' | ' + 'Improvement' + ' | ' +
              'Percentwise imp.')
        print('{:<12}{:>20}{:>20}{:>20}{:>13}%'.format(
            fixed_patches.shape[0], round(pre_av, 4), round(post_av, 4),
            round((post_av - pre_av), 4),
            round(100 - ((pre_av / post_av) * 100), 2)))
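The "Percentwise imp." column is computed as 100 - (pre / post) * 100: for instance, pre_av = 0.60 and post_av = 0.75 would give 100 - 80 = 20% improvement.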
Example #5
def plotTrainPredictions(fixed_batch, moving_batch, predicted_theta, mask, PROJ_ROOT, PROJ_NAME, savefig=False, copperAlpha=1, grayAlpha=0.6):
    batch_size = fixed_batch.shape[0]
    warped_batch = affine_transform(moving_batch, predicted_theta)

    fixed_batch = fixed_batch.detach().numpy()
    moving_batch = moving_batch.detach().numpy()
    warped_batch = warped_batch.detach().numpy()
    mask = mask.detach().numpy()

    x_range = batch_size
    y_range = 2

    fig_x, ax_x = plt.subplots(y_range, x_range, squeeze=False, figsize=(40, 40))
    fig_y, ax_y = plt.subplots(y_range, x_range, squeeze=False, figsize=(40, 40))
    fig_z, ax_z = plt.subplots(y_range, x_range, squeeze=False, figsize=(40, 40))

    count = 0
    for i in range(y_range):
        for j in range(x_range):
            ax_x[i, j].get_xaxis().set_visible(False)
            ax_x[i, j].get_yaxis().set_visible(False)
            ax_y[i, j].get_xaxis().set_visible(False)
            ax_y[i, j].get_yaxis().set_visible(False)
            ax_z[i, j].get_xaxis().set_visible(False)
            ax_z[i, j].get_yaxis().set_visible(False)

            ax_x[i, j].set_xlim([0, fixed_batch.shape[2]])
            ax_x[i, j].set_ylim([fixed_batch.shape[2], 0])
            ax_y[i, j].set_xlim([0, fixed_batch.shape[2]])
            ax_y[i, j].set_ylim([fixed_batch.shape[2], 0])
            ax_z[i, j].set_xlim([0, fixed_batch.shape[2]])
            ax_z[i, j].set_ylim([fixed_batch.shape[2], 0])

            if count != batch_size:
                fixed_x = fixed_batch[count, 0, fixed_batch.shape[2] // 2]
                fixed_y = fixed_batch[count, 0, :, fixed_batch.shape[3] // 2]
                fixed_z = fixed_batch[count, 0, :, :, fixed_batch.shape[4] // 2]

                moving_x = moving_batch[count, 0, moving_batch.shape[2] // 2]
                moving_y = moving_batch[count, 0, :, moving_batch.shape[3] // 2]
                moving_z = moving_batch[count, 0, :, :, moving_batch.shape[4] // 2]

                warped_x = warped_batch[count, 0, warped_batch.shape[2] // 2]
                warped_y = warped_batch[count, 0, :, warped_batch.shape[3] // 2]
                warped_z = warped_batch[count, 0, :, :, warped_batch.shape[4] // 2]

                mask_x = mask[count, 0, mask.shape[2] // 2]
                mask_y = mask[count, 0, :, mask.shape[3] // 2]
                mask_z = mask[count, 0, :, :, mask.shape[4] // 2]

                # Plot x-sliced predictions
                # (matplotlib's imshow accepts origin='upper' or 'lower'; the
                # original 'left'/'lef' values are assumed to mean 'lower')
                ax_x[0, j].imshow(fixed_x, origin='lower', cmap='copper', alpha=copperAlpha)
                ax_x[0, j].imshow(moving_x, origin='lower', cmap='gray', alpha=grayAlpha)
                ax_x[1, j].imshow(fixed_x, origin='lower', cmap='copper', alpha=copperAlpha)
                ax_x[1, j].imshow(warped_x, origin='lower', cmap='gray', alpha=grayAlpha)
                ax_x[0, j].title.set_text('No alignment')
                ax_x[1, j].title.set_text('Predicted alignment')

                # Plot y-sliced predictions
                ax_y[0, j].imshow(fixed_y, origin='lower', cmap='copper', alpha=copperAlpha)
                ax_y[0, j].imshow(moving_y, origin='lower', cmap='gray', alpha=grayAlpha)
                ax_y[1, j].imshow(fixed_y, origin='lower', cmap='copper', alpha=copperAlpha)
                ax_y[1, j].imshow(warped_y, origin='lower', cmap='gray', alpha=grayAlpha)
                ax_y[0, j].title.set_text('No alignment')
                ax_y[1, j].title.set_text('Predicted alignment')

                # Plot z-sliced predictions
                ax_z[0, j].imshow(fixed_z, origin='lower', cmap='copper', alpha=copperAlpha)
                ax_z[0, j].imshow(moving_z, origin='lower', cmap='gray', alpha=grayAlpha)
                ax_z[1, j].imshow(fixed_z, origin='lower', cmap='copper', alpha=copperAlpha)
                ax_z[1, j].imshow(warped_z, origin='lower', cmap='gray', alpha=grayAlpha)
                ax_z[0, j].title.set_text('No alignment')
                ax_z[1, j].title.set_text('Predicted alignment')

                # Overlay the masks on the aligned views
                ax_x[1, j].imshow(mask_x, origin='lower', cmap='cool', alpha=0.1)
                ax_y[1, j].imshow(mask_y, origin='lower', cmap='cool', alpha=0.1)
                ax_z[1, j].imshow(mask_z, origin='lower', cmap='cool', alpha=0.1)

                count += 1

    fig_x.suptitle('x-sliced patchwise predictions for batch_size {}'.format(batch_size))
    fig_y.suptitle('y-sliced patchwise predictions for batch_size {}'.format(batch_size))
    fig_z.suptitle('z-sliced patchwise predictions for batch_size {}'.format(batch_size))
    plt.show()
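A minimal sketch of how this plotting helper might be driven with dummy data, assuming the project's affine_transform helper is importable. The shapes follow the slicing above ([batch, 1, D, H, W]); the identity thetas and the all-ones mask are placeholders:

import torch

batch, d = 2, 32
fixed = torch.rand(batch, 1, d, d, d)
moving = torch.rand(batch, 1, d, d, d)
mask = torch.ones(batch, 1, d, d, d)

# Identity affine parameters: ones on the diagonal of each [3, 4] matrix
theta = torch.eye(3, 4).unsqueeze(0).repeat(batch, 1, 1)

plotTrainPredictions(fixed, moving, theta, mask,
                     PROJ_ROOT='.', PROJ_NAME='demo')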
Example #6
def train(fixed_patches, moving_patches, epoch, model, criterion, optimizer,
          weight, device):
    r"""Training the model
        Args:
            fixed_patches (Tensor): Tensor holding the fixed_patches ([num_patches, 1, patch_size, patch_size, patch_size])
            moving_patches (Tensor): Tensor holding the moving patches ([num_patches, 1, patch_size, patch_size, patch_size])
            epoch (int): current epoch
            model (nn.Module): Network model
            criterion (nn.Module): Loss-function
            optimizer (optim.Optimizer): optimizer in which to optimise the network
            weight (float): float number to weight the regularizer in the loss function
        Returns:
            array of training losses over each batch
    """

    train_set = CreateDataset(fixed_patches, moving_patches)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=True)

    training_loss = torch.zeros(len(train_loader), device=device)

    for batch_idx, (fixed_batch, moving_batch) in enumerate(train_loader):

        fixed_batch = fixed_batch.to(device)
        moving_batch = moving_batch.to(device)
        optimizer.zero_grad()

        predicted_theta = model(fixed_batch, moving_batch)
        predicted_deform = affine_transform(moving_batch, predicted_theta)

        # Loss is complete loss function with regularization. cross_corr = 1 - NCC
        loss, cross_corr = criterion(fixed_batch,
                                     predicted_deform,
                                     predicted_theta,
                                     weight,
                                     reduction='mean')

        if args.precision == 'amp' and not apexImportError:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        training_loss[batch_idx] = cross_corr.item()

        if args.register_hook:
            get_hook(model)

        printer = progress_printer((batch_idx + 1) / len(train_loader))
        print(printer + ' Training epoch {:2}/{} (steps: {})'.format(
            epoch + 1, args.epochs, len(train_loader)),
              end='\r',
              flush=True)

    return training_loss
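A hedged sketch of the epoch loop that would tie train() and validate() together; the patch tensors, model, criterion, optimizer, weight, and args are assumed to be set up as in the surrounding examples:

for epoch in range(args.epochs):
    model.train()
    training_loss = train(fixed_patches, moving_patches, epoch, model,
                          criterion, optimizer, weight, device)

    model.eval()
    with torch.no_grad():
        validation_loss = validate(fixed_patches, moving_patches, epoch,
                                   model, criterion, weight, device)

    print('Epoch {}: train {:.4f} | val {:.4f}'.format(
        epoch + 1, training_loss.mean().item(), validation_loss.mean().item()))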