def gen_dataset(args, train=True):
    """Build and save a dataset with resampled stroke thickness and width.

    Loads a MorphoMNIST-like split from ``args.data_dir``, optionally keeps
    only the images whose label equals ``args.digit_class``, draws one
    (thickness, width) pair per image from ``model`` and rewrites every image
    in place by applying ``SetWidth`` followed by ``SetThickness``. The
    result is written to ``args.out_dir`` together with the *sampled* values
    as metrics.

    Args:
        args: Namespace-like object providing ``data_dir``, ``out_dir``,
            ``digit_class`` (``None`` keeps all digits) and ``scale``.
        train: Which split to load and save; forwarded to the I/O helpers.
    """
    pyro.clear_param_store()
    images, labels, _ = load_morphomnist_like(args.data_dir, train=train)

    # Optional restriction to a single digit class.
    wanted_digit = args.digit_class
    if wanted_digit is not None:
        keep = (labels == wanted_digit)
        images, labels = images[keep], labels[keep]

    n_samples = len(images)
    with torch.no_grad():
        thickness, width = model(n_samples, scale=args.scale)

    # The metrics table holds the sampled targets, not re-measured values.
    metrics = pd.DataFrame(data={'thickness': thickness, 'width': width})

    sampled_pairs = tqdm(zip(thickness, width), total=n_samples)
    for idx, (target_thickness, target_width) in enumerate(sampled_pairs):
        base_morph = ImageMorphology(images[idx], scale=16)

        # Step 1: set the width, then bring the result back down with the
        # original morphology's downscale.
        widened = base_morph.downscale(
            np.float32(SetWidth(target_width)(base_morph)))

        # Step 2: re-wrap the widened image and set the thickness on it.
        widened_morph = ImageMorphology(widened, scale=16)
        thickened = np.float32(SetThickness(target_thickness)(widened_morph))

        images[idx] = base_morph.downscale(thickened)

    # TODO: do we want to save the sampled or the measured metrics?

    save_morphomnist_like(images, labels, metrics, args.out_dir, train=train)
def gen_dataset(args, train=True):
    """Build and save a dataset with resampled slant and stroke thickness.

    Loads a MorphoMNIST-like split from ``args.data_dir``, optionally keeps
    only the images whose label equals ``args.digit_class``, draws one
    (slant, thickness) pair per image from ``model`` and rewrites every image
    in place by applying ``SetThickness`` followed by ``SetSlant``. The
    result is written to ``args.out_dir`` with the sampled values as metrics.

    Args:
        args: Namespace-like object providing ``data_dir``, ``out_dir`` and
            ``digit_class``. ``digit_class=None`` keeps every digit.
        train: Which split to load and save; forwarded to the I/O helpers.
    """
    pyro.clear_param_store()
    images, labels, _ = load_morphomnist_like(args.data_dir, train=train)

    # Only filter when a class was actually requested. Comparing the labels
    # against None would produce an all-False mask and silently generate an
    # empty dataset.
    if args.digit_class is not None:
        mask = (labels == args.digit_class)
        images = images[mask]
        labels = labels[mask]

    n_samples = len(images)
    with torch.no_grad():
        slant, thickness = model(n_samples)

    # The metrics table holds the sampled targets, not re-measured values.
    metrics = pd.DataFrame(data={'thickness': thickness, 'slant': slant})

    # Distinct loop names so the sampled tensors are not shadowed.
    for n, (target_slant, target_thickness) in enumerate(
            tqdm(zip(slant, thickness), total=n_samples)):
        morph = ImageMorphology(images[n], scale=16)
        tmp_img = np.float32(SetThickness(target_thickness)(morph))
        # NOTE(review): scale=1 presumably because tmp_img is already at the
        # upscaled resolution produced by `morph` -- confirm.
        tmp_morph = ImageMorphology(tmp_img, scale=1)
        # The sampled slant is converted from degrees to radians for SetSlant.
        tmp_img = np.float32(SetSlant(np.deg2rad(target_slant))(tmp_morph))
        images[n] = morph.downscale(tmp_img)

    save_morphomnist_like(images, labels, metrics, args.out_dir, train=train)
# Example #3
0
def gen_dataset(args, train=True):
    """Build and save a dataset with resampled thickness and intensity.

    Loads a MorphoMNIST-like split from ``args.data_dir``, optionally keeps
    only the images whose label equals ``args.digit_class``, and draws one
    (thickness, intensity) pair per image from ``model``. Each source image
    is thickened with ``SetThickness``, rescaled so that the value returned
    by ``get_intensity`` is mapped onto the sampled intensity, clipped to
    ``[0, 255]`` and stored in a fresh output array. The result is written
    to ``args.out_dir`` with the *sampled* values as metrics.

    Args:
        args: Namespace-like object providing ``data_dir``, ``out_dir``,
            ``digit_class`` (``None`` keeps all digits), ``scale`` and
            ``invert``.
        train: Which split to load and save; forwarded to the I/O helpers.
    """
    pyro.clear_param_store()
    images_, labels, _ = load_morphomnist_like(args.data_dir, train=train)

    # Optional restriction to a single digit class.
    if args.digit_class is not None:
        keep = (labels == args.digit_class)
        images_, labels = images_[keep], labels[keep]

    # Outputs go into a separate buffer; the loaded images stay untouched.
    images = np.zeros_like(images_)
    n_samples = len(images)

    with torch.no_grad():
        thickness, intensity = model(
            n_samples, scale=args.scale, invert=args.invert)

    # The metrics table holds the sampled targets, not re-measured values.
    metrics = pd.DataFrame(
        data={'thickness': thickness, 'intensity': intensity})

    sampled_pairs = tqdm(zip(thickness, intensity), total=n_samples)
    for idx, (target_thickness, target_intensity) in enumerate(sampled_pairs):
        source_morph = ImageMorphology(images_[idx], scale=16)
        thickened = source_morph.downscale(
            np.float32(SetThickness(target_thickness)(source_morph)))

        # Rescale pixel values so the current intensity maps to the target.
        current_intensity = get_intensity(thickened)
        gain = target_intensity.numpy() / current_intensity

        images[idx] = np.clip(thickened * gain, 0, 255)

    # TODO: do we want to save the sampled or the measured metrics?

    save_morphomnist_like(images, labels, metrics, args.out_dir, train=train)
# Example #4
0
def subsample(filter_fn, source_dir, train: bool, target_dir):
    """Copy the subset of a MorphoMNIST-like dataset selected by ``filter_fn``.

    Args:
        filter_fn: Callable that receives the loaded metrics and returns an
            index (e.g. a boolean mask) used to select images, labels and
            metrics rows alike.
        source_dir: Directory the dataset split is loaded from.
        train: Which split to load and save.
        target_dir: Directory the selected subset is written to.
    """
    # Load with the (directory, train=...) convention used by the other
    # callers; the source directory must not be passed a second time.
    images, labels, metrics = load_morphomnist_like(source_dir, train=train)
    idx = filter_fn(metrics)
    # The output directory comes before the `train` flag; passing them
    # swapped would use the boolean as the path.
    save_morphomnist_like(images[idx], labels[idx], metrics[idx], target_dir,
                          train=train)
# Example #5
0
    # target_thickness = 6 if label % 2 == 0 else 1.5
    # trf = transforms.SetThickness(target_thickness)
    target_thickness = max(1, 2.5 * np.exp(2 * metrics.slant))
    trf = transforms.SetThickness(target_thickness)
    trf_image = morph.downscale(trf(morph))
    trf_metrics = measure_image(trf_image, scale=4, verbose=False)
    return trf_image, label, trf_metrics


if __name__ == '__main__':
    # Demo script: run `example_fn` over the first few test images through
    # `apply_conditional_transformation` and draw the results in a grid.
    import os

    root_dir = "/vol/biomedic/users/dc315/mnist"
    source_dir = os.path.join(root_dir, "original")
    # NOTE(review): target_dir is not used in the visible part of the script.
    target_dir = os.path.join(root_dir, "sub_th3_sl0")
    images, labels, metrics = load_morphomnist_like(source_dir, train=False)
    # with multiprocessing.Pool() as pool:
    # No worker pool is supplied; presumably the helper then runs serially
    # -- confirm against apply_conditional_transformation.
    pool = None
    # Grid dimensions; also determines how many images are transformed.
    nrow, ncol = 9, 12
    # NOTE(review): only `images` is truncated to nrow * ncol entries, while
    # `labels` and `metrics` are passed in full -- confirm the helper copes
    # with the length mismatch.
    trf_images, trf_labels, trf_metrics = apply_conditional_transformation(
        example_fn, images[:nrow * ncol], labels, metrics, pool)

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(nrow, ncol)
    for i, ax in enumerate(axs.flat):
        # Stop once there are no more transformed images to show.
        if i >= len(trf_images):
            break
        ax.imshow(trf_images[i], cmap='gray_r')
        # ax.set_title(f"label: {trf_labels[i]}")
        ax.axis('off')