예제 #1
0
파일: train.py 프로젝트: ninatubau/3D-RCAN
def load_data(config, data_type):
    """Load the raw / ground-truth (GT) image pairs configured for one split.

    Pairs are gathered from two optional config entries, both keyed by the
    ``data_type`` prefix:

    * ``config[data_type + '_image_pairs']`` -- an explicit sequence of
      ``{'raw': path, 'gt': path}`` dicts (a missing or ``None`` entry is
      treated as empty);
    * ``config[data_type + '_data_dir']`` -- a mapping with ``'raw'`` and
      ``'gt'`` directories. Every ``*.tif`` file in each directory is paired
      with the file at the same position in sorted order.

    The caller's config is never modified.

    Args:
        config: Configuration mapping.
        data_type: Key prefix selecting the split (presumably values such as
            ``'training'`` / ``'validation'`` -- TODO confirm with callers).

    Returns:
        A list of ``[raw, gt]`` entries, each image read with ``tifffile`` and
        passed through ``normalize``. Returns ``None`` when no pair is
        configured.

    Raises:
        RuntimeError: If the raw directory contains no TIFF file, or the raw
            and GT directories contain different numbers of TIFF files.
        ValueError: If a raw image and its GT image differ in shape.
    """
    # Copy the configured list. Otherwise the pairs discovered in the data
    # directories below would be appended to the caller's config, and every
    # later call with the same config would see them twice.
    image_pair_list = list(config.get(data_type + '_image_pairs') or [])

    if data_type + '_data_dir' in config:
        data_dir = config[data_type + '_data_dir']
        raw_dir = pathlib.Path(data_dir['raw'])
        gt_dir = pathlib.Path(data_dir['gt'])

        # Pairing is purely positional: the i-th raw file (sorted by path)
        # goes with the i-th GT file.
        raw_files = sorted(raw_dir.glob('*.tif'))
        gt_files = sorted(gt_dir.glob('*.tif'))

        if not raw_files:
            raise RuntimeError(f'No TIFF file found in {raw_dir}')

        if len(raw_files) != len(gt_files):
            raise RuntimeError(
                f'"{raw_dir}" and "{gt_dir}" must contain the same number of '
                'TIFF files')

        image_pair_list.extend(
            {'raw': str(raw_file), 'gt': str(gt_file)}
            for raw_file, gt_file in zip(raw_files, gt_files))

    if not image_pair_list:
        return None

    print(f'Loading {data_type} data')

    data = []
    for pair in image_pair_list:
        raw_file, gt_file = pair['raw'], pair['gt']

        print('  - raw:', raw_file)
        print('    gt:', gt_file)

        raw = tifffile.imread(raw_file)
        gt = tifffile.imread(gt_file)

        if raw.shape != gt.shape:
            raise ValueError(
                'Raw and GT images must be the same size: '
                f'{raw_file} {raw.shape} vs. {gt_file} {gt.shape}')

        data.append([normalize(raw), normalize(gt)])

    return data
예제 #2
0
파일: apply.py 프로젝트: eguomin/3D-RCAN
# Overlap between neighbouring blocks during tiled inference. When it is not
# given on the command line, use 1/8 of each model-input dimension (at least
# 1), or 0 for dimensions of size <= 2. The [1:-1] slice drops the first and
# last input axes (presumably batch and channel -- TODO confirm).
if args.block_overlap_shape is None:
    overlap_shape = [
        max(1, x // 8) if x > 2 else 0
        for x in model.input.shape.as_list()[1:-1]
    ]
else:
    overlap_shape = args.block_overlap_shape

# Output intensity scale factor, defaulting to 2000. It is not used in this
# fragment; presumably it is applied when the result is saved (TODO confirm).
sValue = 2000 if args.scale_value is None else args.scale_value

for raw_file, gt_file in data:
    print('Loading raw image from', raw_file)
    raw = normalize(tifffile.imread(str(raw_file)))

    print('Applying model')
    restored = apply(model, raw, overlap_shape=overlap_shape, verbose=True)

    # By default only the DL-restored image is kept; the raw input is no
    # longer included in the output.
    result = restored

    if gt_file is not None:
        print('Loading ground truth image from', gt_file)
        gt = tifffile.imread(str(gt_file))
        if raw.shape == gt.shape:
            # `restored` is the NumPy array returned by `apply`, which has no
            # `append` method. Build a new list instead of calling
            # `result.append(...)`, which raised AttributeError here.
            result = [restored, normalize(gt)]
        else:
            print('Ground truth image discarded due to image shape mismatch')
예제 #3
0
# Overlap between neighbouring blocks during tiled inference. When it is not
# given on the command line, use 1/8 of each model-input dimension (at least
# 1), or 0 for dimensions of size <= 2. The [1:-1] slice drops the first and
# last input axes (presumably batch and channel -- TODO confirm).
if args.block_overlap_shape is None:
    overlap_shape = [
        max(1, x // 8) if x > 2 else 0
        for x in model.input.shape.as_list()[1:-1]]
else:
    overlap_shape = args.block_overlap_shape

# Output intensity scale factor, defaulting to 2000. It is not used in this
# fragment; presumably it is applied when the result is saved (TODO confirm).
sValue = 2000 if args.scale_value is None else args.scale_value

for raw_file, gt_file in data:
    print('Loading raw image from', raw_file)
    # p_min / p_max are forwarded to `normalize`; presumably these are
    # percentile bounds (TODO confirm against `normalize`).
    raw = normalize(tifffile.imread(str(raw_file)), args.p_min, args.p_max)

    print('Applying model')
    restored = apply(model, raw, overlap_shape=overlap_shape, verbose=True)

    # Without a usable ground truth, only the DL-restored image is kept.
    result = restored

    if gt_file is not None:
        print('Loading ground truth image from', gt_file)
        gt = tifffile.imread(str(gt_file))
        if raw.shape == gt.shape:
            gt = normalize(gt, args.p_min, args.p_max)
            if args.rescale:
                restored = rescale(restored, gt)
            # With a matching GT, save raw, restored and GT together.
            result = [raw, restored, gt]
        else:
            # Report the mismatch instead of dropping the GT silently.
            print('Ground truth image discarded due to image shape mismatch')
예제 #4
0
        elif config['step1_output_bit'] == 16:
            result = np.clip(result, 0, 65535).astype('uint16')

        if input_interp_path.is_dir():
            output_file = input_interp_path / raw_file.name
        else:
            output_file = input_interp_path

        print('Saving step1 output image to', output_file)
        tifffile.imwrite(str(output_file), result, imagej=False)

    # # # Step 1: de-aberration
    if config['step1_trigger']:

        print('Applying step1 model')
        raw = normalize(raw)
        restored = apply(step1_model,
                         raw,
                         overlap_shape=step1_overlap_shape,
                         verbose=True)
        restored[restored < 0] = 0

        if config['step1_output_trigger']:
            result = np.stack(restored)
            if result.ndim == 4:
                result = np.transpose(result, (1, 0, 2, 3))

            if config['step1_output_bit'] == 8:
                result = np.clip(sValue * result, 0, 255).astype('uint8')
            elif config['step1_output_bit'] == 16:
                result = np.clip(sValue * result, 0, 65535).astype('uint16')