Example #1
def main():
    config = load_config()

    # make sure those values correspond to the ones used during training
    in_channels = config['in_channels']
    out_channels = config['out_channels']
    # use the same upsampling as during training
    interpolate = config['interpolate']
    # specify the layer ordering used during training
    layer_order = config['layer_order']
    # whether to apply sigmoid as the final activation layer
    final_sigmoid = config['final_sigmoid']
    init_channel_number = config['init_channel_number']
    model = UNet3D(in_channels,
                   out_channels,
                   init_channel_number=init_channel_number,
                   final_sigmoid=final_sigmoid,
                   interpolate=interpolate,
                   conv_layer_order=layer_order)

    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)

    device = config['device']
    model = model.to(device)

    logger.info('Loading datasets...')
    for test_dataset in get_test_datasets(config):
        # run the model prediction on the entire dataset
        probability_maps = predict(model, test_dataset, out_channels, device)
        # save the resulting probability maps
        output_file = _get_output_file(test_dataset)
        save_predictions(probability_maps, output_file)
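
Example #1 calls a `_get_output_file` helper that is not shown in this snippet; Example #4 defines one variant inline. A sketch consistent with that definition (assuming the test dataset exposes the source volume path as `dataset.path`):

import os

def _get_output_file(dataset):
    # write '<volume>_probabilities.<ext>' next to the input volume,
    # mirroring the helper defined inside Example #4's main()
    volume_file, volume_ext = os.path.splitext(dataset.path)
    return volume_file + '_probabilities' + volume_ext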
Example #2
def main():
    # Load configuration
    config = load_config()

    # create logger
    logfile = config.get('logfile', None)
    logger = utils.get_logger('UNet3DPredictor', logfile=logfile)

    # Create the model
    model = get_model(config)

    # use multiple GPUs if available
    if torch.cuda.device_count() > 1:
        logger.info(f'There are {torch.cuda.device_count()} GPUs available')
        model = nn.DataParallel(model)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    logger.info('Loading HDF5 datasets...')
    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")

        output_file = _get_output_file(config['output_folder'],
                                       test_loader.dataset)
        logger.info(output_file)
        predictor = _get_predictor(model, test_loader, output_file, config)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        predictor.predict()
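
Note that Example #2 wraps the model in `nn.DataParallel` before loading the checkpoint. `DataParallel` prefixes every parameter name with `module.`, so the checkpoint keys must carry the same prefix (or `load_checkpoint` must strip it). A hedged sketch of normalizing the keys manually, in case the checkpoint was saved from an unwrapped model (the helper name is hypothetical, not part of the source):

def strip_module_prefix(state_dict):
    # drop the 'module.' prefix that nn.DataParallel adds to parameter names
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}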
Example #3
def main():
    # Load configuration
    config = load_config()

    # Create the model
    model = get_model(config)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    logger.info('Loading HDF5 datasets...')
    store_predictions_in_memory = config.get('store_predictions_in_memory',
                                             True)
    if store_predictions_in_memory:
        logger.info(
            'Predictions will be stored in memory. Make sure you have enough RAM for your dataset.'
        )

    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")

        output_file = _get_output_file(test_loader.dataset)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        if store_predictions_in_memory:
            predict_in_memory(model, test_loader, output_file, config)
        else:
            predict(model, test_loader, output_file, config)
Example #4
def main():
    def _get_output_file(dataset):
        volume_file, volume_ext = os.path.splitext(dataset.path)
        return volume_file + '_probabilities' + volume_ext

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, help='path to the model')
    parser.add_argument('--config-path', type=str,
                        help='path to the dataset config')
    parser.add_argument('--in-channels', default=1, type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels', default=6, type=int,
                        help='number of output channels')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'brc' -> BatchNorm3d+ReLU+Conv3D",
                        default='brc')
    parser.add_argument('--interpolate',
                        help='use F.interpolate instead of ConvTranspose3d',
                        action='store_true')

    args = parser.parse_args()

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = args.interpolate
    # Conv layer ordering e.g. 'cr' is equivalent to Conv3D+ReLU
    conv_layer_order = args.layer_order
    model = UNet3D(in_channels, out_channels, interpolate=interpolate,
                   final_sigmoid=True, conv_layer_order=conv_layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')
    raw_volumes = get_raw_volumes(args.config_path)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning(
            'No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    for raw_volume in raw_volumes:
        probability_maps = predict(model, raw_volume, device)

        output_file = _get_output_file(raw_volume)

        save_predictions(probability_maps, output_file)
Example #5
def load_model(checkpoint_path):
    # NOTE: InstantiatedModel is a placeholder; build the same architecture
    # (with the same constructor arguments) that was used during training
    model = InstantiatedModel
    utils.load_checkpoint(checkpoint_path, model, map_location='cuda:0')
    model.cuda()

    model.eval()
    return model
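
A minimal usage sketch for `load_model` (the checkpoint path and `test_loader` are hypothetical; the loader is assumed to yield input tensors, as in the other examples):

import torch

model = load_model('checkpoints/best_checkpoint.pytorch')  # hypothetical path
with torch.no_grad():  # inference only, no gradient tracking
    for batch in test_loader:  # hypothetical DataLoader over the test set
        probs = model(batch.to('cuda:0'))  # forward pass on the GPU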
Example #6
def main():
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the model')
    parser.add_argument('--in-channels', required=True, type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels', required=True, type=int,
                        help='number of output channels')
    parser.add_argument('--interpolate',
                        help='use F.interpolate instead of ConvTranspose3d',
                        action='store_true')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'brc' -> BatchNorm3d+ReLU+Conv3D",
                        default='brc')

    args = parser.parse_args()

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = args.interpolate
    layer_order = args.layer_order
    model = UNet3D(in_channels, out_channels, interpolate=interpolate,
                   final_sigmoid=True, conv_layer_order=layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning(
            'No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    dataset, dataset_size = _get_dataset(1)
    probability_maps = predict(model, dataset, dataset_size, device)

    output_file = os.path.join(os.path.split(args.model_path)[0],
                               'probabilities.h5')

    save_predictions(probability_maps, output_file)
Example #7
def main():
    config = parse_test_config()

    # make sure those values correspond to the ones used during training
    in_channels = config.in_channels
    out_channels = config.out_channels
    # use F.interpolate for upsampling
    interpolate = config.interpolate
    layer_order = config.layer_order
    final_sigmoid = config.final_sigmoid
    model = UNet3D(in_channels,
                   out_channels,
                   init_channel_number=config.init_channel_number,
                   final_sigmoid=final_sigmoid,
                   interpolate=interpolate,
                   conv_layer_order=layer_order)

    logger.info(f'Loading model from {config.model_path}...')
    utils.load_checkpoint(config.model_path, model)

    logger.info('Loading datasets...')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for prediction...')
        device = torch.device('cpu')

    model = model.to(device)

    patch = tuple(config.patch)
    stride = tuple(config.stride)

    for test_path in config.test_path:
        # create dataset for a given test file
        dataset = HDF5Dataset(test_path,
                              patch,
                              stride,
                              phase='test',
                              raw_internal_path=config.raw_internal_path)
        # run the model prediction on the entire dataset
        probability_maps = predict(model, dataset, out_channels, device)
        # save the resulting probability maps
        output_file = f'{os.path.splitext(test_path)[0]}_probabilities.h5'
        save_predictions(probability_maps, output_file)
Example #8
def main():
    # Load configuration
    config = load_config()

    # Create the model
    model = get_model(config)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    logger.info('Loading HDF5 datasets...')
    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")

        output_file = _get_output_file(test_loader.dataset)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        predict(model, test_loader, output_file, config)
Example #9
def main():
    # Load configuration
    config = load_config()

    # Create the model
    model = get_model(config)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    logger.info('Loading HDF5 datasets...')

    test_loader = get_test_loaders(config)['test']
    for i, data_pair in enumerate(test_loader):
        output_file = f'predict_{i}.h5'
        predictor = _get_predictor(model, data_pair, output_file, config)
        predictor.predict()
Example #10
def main():
    # Load configuration
    config = load_config()

    # Create the model
    model = get_model(config)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    model = model.to(config['device'])

    logger.info('Loading HDF5 datasets...')
    for test_dataset in get_test_datasets(config):
        logger.info(f"Processing '{test_dataset.file_path}'...")
        # run the model prediction on the entire dataset
        predictions = predict(model, test_dataset, config)
        # save the resulting probability maps
        output_file = _get_output_file(test_dataset)
        dataset_names = _get_dataset_names(config, len(predictions))
        save_predictions(predictions, output_file, dataset_names)
Example #11
def main():
	# Load configuration
	config = load_config()

	# Create the model
	model = get_model(config)

	# Create evaluation metric
	eval_criterion = get_evaluation_metric(config)

	# Load model state
	model_path = config['model_path']
	logger.info(f'Loading model from {model_path}...')
	utils.load_checkpoint(model_path, model)
	model = model.to(config['device'])

	logger.info('Loading HDF5 datasets...')
	
	# ========================== for data batch, score recording ==========================
	nii_path = "/data/cephfs/punim0877/liver_segmentation_v1/Test_Batch"  # path to the NIfTI test batch
	hdf5_path = "./resources/hdf5ed_test_data"  # directory for the converted HDF5 files
	stage = 1

	for index in range(110, 131):  # TODO: drop this loop if only a single file is needed
		if not hdf5_it(nii_path, hdf5_path, index, stage):
			continue
		config["datasets"]["test_path"] = []
		for hdf5_file in os.listdir(hdf5_path):
			print("adding %s to test list" % hdf5_file)
			config["datasets"]["test_path"].append(os.path.join(hdf5_path, hdf5_file))
		
		for test_dataset in get_test_datasets(config):
			logger.info(f"Processing '{test_dataset.file_path}'...")
			# run the model prediction on the entire dataset
			predictions = predict(model, test_dataset, config, eval_criterion)
			# save the resulting probability maps
			output_file = _get_output_file(test_dataset)
			dataset_names = _get_dataset_names(config, len(predictions))
			save_predictions(predictions, output_file, dataset_names)
Example #12
def get_job_name():
    now = '{:%Y-%m-%d.%H:%M}'.format(datetime.datetime.now())
    return "%s_model" % (now)


logger = utils.get_logger('UNet3DPredictor')

# Load and log experiment configuration
config = load_config()

# Load model state
model = get_model(config)
model_path = config['trainer']['test_model']
logger.info(f'Loading model from {model_path}...')
utils.load_checkpoint(model_path, model)

# Run on GPU or CPU
logger.info(f"Sending the model to '{config['device']}'")
model = model.to(config['device'])

predictionsBasePath = config['loaders']['pred_path']
Example #13
def main():
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--cdmodel-path', required=True, type=str,
                        help='path to the coordinate detector model.')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the segmentation model')
    parser.add_argument('--in-channels', type=int, default=1,
                        help='number of input channels (default: 1)')
    parser.add_argument('--out-channels', type=int, default=2,
                        help='number of output channels (default: 2)')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
                        default='crg')
    parser.add_argument('--final-sigmoid',
                        action='store_true',
                        help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, nargs='+', required=True, help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', type=int, nargs='+', default=None,
                        help='Patch shape used for prediction on the test set')
    parser.add_argument('--stride', type=int, nargs='+', default=None,
                        help='Patch stride used for prediction on the test set')
    parser.add_argument('--report-metrics', action='store_true',
                        help='Whether to print metrics for each prediction')
    parser.add_argument('--output-path', type=str, default='./output/',
                        help='The output path to generate the nifti file')

    args = parser.parse_args()

    # Create the output path if it does not exist
    if not os.path.isdir(args.output_path):
        os.makedirs(args.output_path)

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = True
    layer_order = args.layer_order
    final_sigmoid = args.final_sigmoid

    # Define model
    UNet_model = UNet3D(in_channels, out_channels,
                       init_channel_number=args.init_channel_number,
                       final_sigmoid=final_sigmoid,
                       interpolate=interpolate,
                       conv_layer_order=layer_order)
    Coor_model = CoorNet(in_channels)

    # Define metrics
    loss = nn.MSELoss(reduction='sum')
    acc = DiceCoefficient()
    
    logger.info('Loading trained coordinate detector model from ' + args.cdmodel_path)
    utils.load_checkpoint(args.cdmodel_path, Coor_model)

    logger.info('Loading trained segmentation model from ' + args.model_path)
    utils.load_checkpoint(args.model_path, UNet_model)

    # Load the model to the device
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')
    UNet_model = UNet_model.to(device)
    Coor_model = Coor_model.to(device)

    # Use the given patch shape and stride if provided
    if args.patch and args.stride:
        patch = tuple(args.patch)
        stride = tuple(args.stride)

    # Initialise counters
    total_dice = 0
    total_loss = 0
    count = 0
    tmp_created = False

    for test_path in args.test_path:
        if test_path.endswith('.nii.gz'):
            if args.report_metrics:
                raise ValueError("Cannot report metrics on original files.")
            # Temporary save as h5 file
            # Preprocess if dim != 192 x 224 x 192
            data = preprocess_nifti(test_path, args.output_path)
            logger.info('Preprocessing complete.')
            hf = h5py.File(test_path + '.h5', 'w')
            hf.create_dataset('raw', data=data)
            hf.close()
            test_path += '.h5'
            tmp_created = True
        if not args.patch and not args.stride:
            # read the volume shape from the HDF5 metadata instead of loading the whole array
            with h5py.File(test_path, 'r') as f:
                curr_shape = f[args.raw_internal_path].shape
            patch = curr_shape
            stride = curr_shape

        # Initialise dataset
        dataset = HDF5Dataset(test_path, patch, stride, phase='test', raw_internal_path=args.raw_internal_path)        

        file_name = test_path.split('/')[-1].split('.')[0]
        # Predict the centre coordinates
        x, y, z = predict(Coor_model, dataset, out_channels, device)

        # Perform segmentation
        probability_maps = predict(UNet_model, dataset, out_channels, device, x, y, z)
        res = np.argmax(probability_maps, axis=0)

        # Put the image batch back to mask with the original size
        res = recover_patch(res, x, y, z, dataset.raw.shape)

        # Extract LH and RH segmentations and write as file
        LH = np.zeros(res.shape)
        LH[int(res.shape[0]/2):,:,:] = res[int(res.shape[0]/2):,:,:]
        RH = np.zeros(res.shape)
        RH[:int(res.shape[0]/2),:,:] = res[:int(res.shape[0]/2),:,:]
        
        LH_img = nib.Nifti1Image(LH, AFF)
        RH_img = nib.Nifti1Image(RH, AFF)
        nib.save(LH_img, args.output_path + file_name + '_LH.nii.gz')
        nib.save(RH_img, args.output_path + file_name + '_RH.nii.gz')
        logger.info('File saved to ' + args.output_path + file_name + '_LH.nii.gz')
        logger.info('File saved to ' + args.output_path + file_name + '_RH.nii.gz')
        
        if tmp_created:
            os.remove(test_path)

        if args.report_metrics:
            count += 1

            # Compute coordinate accuracy
            # Coordinate evaluation disabled by default, since not all data have coordinate information
            # coor_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='coor')
            # coor_target = coor_dataset[0][1].to(device)
            # coor_pred_tensor = torch.from_numpy(np.array([x, y, z])).to(device)
            # curr_coor_loss = loss(coor_pred_tensor, coor_target)
            # total_loss += curr_coor_loss
            # logger.info('Current coordinate loss: %f' % (curr_coor_loss))

            # Compute segmentation Dice score
            label_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='label')
            label_target = label_dataset[0][1].to(device)
            res_dice = probability_maps
            new_shape = np.append(res_dice.shape[0], np.array(label_target.size()))
            res_dice = recover_patch_4d(res_dice, x, y, z, new_shape)
            pred_tensor = torch.from_numpy(res_dice).to(device).float()
            label_target = label_target.view((1,) + label_target.shape)
            curr_dice_score = acc(pred_tensor, label_target.long())
            total_dice += curr_dice_score
            logger.info('Current Dice score: %f' % (curr_dice_score))

            # Compute length estimation
            logger.info('RH length: ' + str(get_total_dist(res[:int(res.shape[0]/2),:,:])))
            logger.info('LH length: ' + str(get_total_dist(res[int(res.shape[0]/2):,:,:])))
    
    if args.report_metrics:       
        # logger.info('Average loss: %f.' % (total_loss/count))
        logger.info('Average Dice score: %f.' % (total_dice/count))
Example #14
def main():
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path',
                        required=True,
                        type=str,
                        help='path to the model')
    parser.add_argument('--in-channels',
                        required=True,
                        type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels',
                        required=True,
                        type=int,
                        help='number of output channels')
    parser.add_argument(
        '--init-channel-number',
        type=int,
        default=64,
        help=
        'Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)'
    )
    parser.add_argument('--interpolate',
                        help='use F.interpolate instead of ConvTranspose3d',
                        action='store_true')
    parser.add_argument(
        '--layer-order',
        type=str,
        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
        default='crg')
    parser.add_argument(
        '--final-sigmoid',
        action='store_true',
        help=
        'if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax'
    )
    parser.add_argument('--test-path',
                        type=str,
                        required=True,
                        help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument(
        '--patch',
        required=True,
        type=int,
        nargs='+',
        help='Patch shape used for prediction on the test set')
    parser.add_argument(
        '--stride',
        required=True,
        type=int,
        nargs='+',
        help='Patch stride used for prediction on the test set')

    args = parser.parse_args()

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = args.interpolate
    layer_order = args.layer_order
    final_sigmoid = args.final_sigmoid
    model = UNet3D(in_channels,
                   out_channels,
                   init_channel_number=args.init_channel_number,
                   final_sigmoid=final_sigmoid,
                   interpolate=interpolate,
                   conv_layer_order=layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    patch = tuple(args.patch)
    stride = tuple(args.stride)

    dataset = HDF5Dataset(args.test_path,
                          patch,
                          stride,
                          phase='test',
                          raw_internal_path=args.raw_internal_path)
    probability_maps = predict(model, dataset, out_channels, device)

    output_file = f'{os.path.splitext(args.test_path)[0]}_probabilities.h5'

    # average channels only in case of final_sigmoid
    save_predictions(probability_maps, output_file, final_sigmoid)
Example #15
def main():
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path',
                        required=True,
                        type=str,
                        help='path to the model')
    parser.add_argument('--in-channels',
                        required=True,
                        type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels',
                        required=True,
                        type=int,
                        help='number of output channels')
    parser.add_argument('--interpolate',
                        help='use F.interpolate instead of ConvTranspose3d',
                        action='store_true')
    parser.add_argument(
        '--average-channels',
        help=
        'average the probability_maps across the channel axis (use only if your channels refer to the same semantic class)',
        action='store_true')
    parser.add_argument(
        '--layer-order',
        type=str,
        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
        default='crg')
    parser.add_argument(
        '--loss',
        type=str,
        required=True,
        help=
        'Loss function used for training. Possible values: [ce, bce, wce, dice]. Has to be provided because the loss determines the final activation of the model.'
    )
    parser.add_argument('--test-path',
                        type=str,
                        required=True,
                        help='path to the test dataset')
    parser.add_argument(
        '--patch',
        required=True,
        type=int,
        nargs='+',
        help='Patch shape used for prediction on the test set')
    parser.add_argument(
        '--stride',
        required=True,
        type=int,
        nargs='+',
        help='Patch stride used for prediction on the test set')

    args = parser.parse_args()

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = args.interpolate
    layer_order = args.layer_order
    final_sigmoid = _final_sigmoid(args.loss)
    model = UNet3D(in_channels,
                   out_channels,
                   final_sigmoid=final_sigmoid,
                   interpolate=interpolate,
                   conv_layer_order=layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    patch = tuple(args.patch)
    stride = tuple(args.stride)

    dataset = HDF5Dataset(args.test_path, patch, stride, phase='test')
    probability_maps = predict(model, dataset, dataset.raw.shape, device)

    output_file = f'{os.path.splitext(args.test_path)[0]}_probabilities.h5'

    save_predictions(probability_maps, output_file, args.average_channels)
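
Example #15 derives the final activation from the training loss via `_final_sigmoid`, whose definition is not shown. A minimal sketch of such a helper, assuming the loss names listed in the `--loss` help text (this mapping is an assumption, not the source's definition):

def _final_sigmoid(loss):
    # sigmoid-based losses (bce, dice) imply a sigmoid final layer;
    # cross-entropy variants (ce, wce) imply softmax
    assert loss in ['ce', 'bce', 'wce', 'dice']
    return loss in ['bce', 'dice']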