Ejemplo n.º 1
0
def Show_patches(X, Y, X_val=None, Y_val=None):
    '''
    Visualize the first five training and validation patch pairs.

    Parameters
    ----------
    X : (np.ndarray) array of training source patches
    Y : (np.ndarray) array of training target patches
    X_val : (np.ndarray, optional) array of validation source patches.
        If omitted, a module-level ``X_val`` is used when one exists
        (the original implementation read that global implicitly).
    Y_val : (np.ndarray, optional) array of validation target patches,
        resolved the same way as ``X_val``.

    Returns
    ---------
    void

    Notes
    -----
    When no validation patches are available, only the training figure is
    drawn instead of failing with a NameError.
    '''
    # Backward compatibility: the original body referenced module-level
    # X_val / Y_val globals, so keep them as a fallback for existing callers.
    if X_val is None:
        X_val = globals().get('X_val')
    if Y_val is None:
        Y_val = globals().get('Y_val')

    # plot of training patches
    plt.figure(figsize=(12, 5))
    plot_some(X[:5], Y[:5])
    plt.suptitle(
        '5 example training patches (top row: source, bottom row: target)')

    # plot of validation patches (only when they are available)
    if X_val is None or Y_val is None:
        return
    plt.figure(figsize=(12, 5))
    plot_some(X_val[:5], Y_val[:5])
    plt.suptitle(
        '5 example validation patches (top row: source, bottom row: target)')
Ejemplo n.º 2
0
def run_z(y, x, models):
    """Score CARE models on a 'ZYX' stack against ground truth and plot them.

    The input ``x`` and ground truth ``y`` are cast to float64 and percentile
    normalized with ``util.percentile_norm``. The normalized input and the
    normalized mean prediction of every model are then scored against the
    normalized ground truth with RMSE and SSIM. Each model is loaded with
    ``CARE(config=None, name=..., basedir='models')`` and run with
    ``predict_probabilistic`` using ``n_tiles=(1, 4, 4)``.

    Two figures are shown:
    - one slice along the first axis, chosen with ``util.get_randint_ixs``,
      for the input, the normalized GT and every prediction;
    - a seaborn bar plot of the metrics, with the y-limits set to (0, 1).

    Parameters
    ----------
    y : array-like
        Ground-truth stack with axes 'ZYX'.
    x : array-like
        Input stack with axes 'ZYX'.
    models : iterable of str
        Names of saved CARE models in the 'models' directory.
    """
    # model prediction and normalizations output float64
    x = np.array(x, dtype='float64')
    y = np.array(y, dtype='float64')
    axes = 'ZYX'

    def score(label, truth, candidate):
        # One metrics-table row: [label, rmse, ssim]
        return [label, np.sqrt(mse(candidate, truth)), ssim(candidate, truth)]

    # Normalize the GT dynamic range so the numbers are comparable
    y_norm = util.percentile_norm(y, axes)
    x_norm = util.percentile_norm(x, axes)

    rows = [score('input', y_norm, x_norm)]
    labels = ['input', 'N(GT)']
    images = [x_norm, y_norm]

    for model_name in models:
        care = CARE(config=None, name=model_name, basedir='models')
        dist = care.predict_probabilistic(x, axes, n_tiles=(1, 4, 4))
        pred_norm = util.percentile_norm(dist.mean(), axes)
        rows.append(score(model_name, y_norm, pred_norm))
        labels.append(model_name)
        images.append(pred_norm)

    # Plot a random slice of every stack in a single row
    z = util.get_randint_ixs(1, len(y))
    plt.figure(figsize=(16, 10))
    plot_some(np.stack([[img[z] for img in images]]), title_list=[labels])
    plt.show()

    # Long-format table (output, metric, value) for the grouped bar plot
    metrics = pd.melt(
        pd.DataFrame(rows, columns=['output', 'rmse', 'ssim']),
        id_vars=['output'],
        var_name='metric',
    )
    grid = sns.catplot(x='metric',
                       y='value',
                       hue='output',
                       kind='bar',
                       sharey=False,
                       data=metrics)
    grid.ax.set_ylim(0, 1)
    plt.show()
Ejemplo n.º 3
0
    def train(self, channels=None, **config_args):
        """Train one CARE model per channel, show diagnostics and export it.

        For every channel ``ch`` the method:
        - loads ``CH_<ch>_training_patches.npz`` from the training-patch path,
          holding out 10% of it for validation;
        - trains a model named ``CH_<ch>_model`` under ``<out_dir>/models``;
        - plots the loss history and five validation predictions;
        - exports the model with ``export_TF``.

        Parameters
        ----------
        channels : iterable, optional
            Channels to train. Defaults to ``self.train_channels``.
        **config_args
            Extra keyword arguments forwarded to ``Config``.

        If training raises ``tf.errors.ResourceExhaustedError``, a message is
        printed and the method returns without handling the remaining
        channels.
        """
        # limit_gpu_memory(fraction=1)
        for ch in (self.train_channels if channels is None else channels):
            print("-- Training channel {}...".format(ch))
            patch_file = (self.get_training_patch_path() /
                          'CH_{}_training_patches.npz'.format(ch))
            (X, Y), (X_val, Y_val), axes = load_training_data(
                patch_file, validation_split=0.1, verbose=False)

            # Channel counts are read from the 'C' axis
            channel_axis = axes_dict(axes)['C']
            config = Config(axes,
                            X.shape[channel_axis],
                            Y.shape[channel_axis],
                            train_epochs=self.train_epochs,
                            train_steps_per_epoch=self.train_steps_per_epoch,
                            train_batch_size=self.train_batch_size,
                            **config_args)
            model = CARE(config,
                         'CH_{}_model'.format(ch),
                         basedir=pathlib.Path(self.out_dir) / 'models')

            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                )
                return

            # Learning curve
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

            # Five validation examples: source / target / prediction
            plt.figure(figsize=(12, 7))
            predicted = model.keras_model.predict(X_val[:5])
            plot_some(X_val[:5], Y_val[:5], predicted, pmax=99.5, cmap="gray")
            plt.suptitle('5 example validation patches\n'
                         'top row: input (source),  '
                         'middle row: target (ground truth),  '
                         'bottom row: predicted from source')
            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")
Ejemplo n.º 4
0
def showPlot(X, Y, XY_axes, n_rows=3, per_row=5):
    """Show consecutive source/target patch pairs, one figure per group.

    Figure ``i`` shows the patches with indices ``per_row * i`` up to
    ``per_row * (i + 1) - 1``, taken as ``X[start:stop, 0]``. X is on the top
    row and Y on the bottom row, and each column is titled with its patch
    index.

    Parameters
    ----------
    X, Y : np.ndarray
        Source and target patches. Axis 0 is the sample axis; index 0 is
        taken on axis 1, presumably the channel axis (TODO confirm against
        ``XY_axes``).
    XY_axes : str
        Axes string of X/Y. Currently unused; kept for interface
        compatibility.
    n_rows : int, optional
        Maximum number of figures (default 3, as before).
    per_row : int, optional
        Patches per figure (default 5, as before).

    Drawing stops once the patches run out. The original handed empty or
    short slices, with mismatched titles, to ``plot_some`` when fewer than
    ``n_rows * per_row`` patches were available.
    """
    n_patches = min(len(X), len(Y))
    for i in range(n_rows):
        start = per_row * i
        if start >= n_patches:
            break  # nothing left to show
        stop = min(start + per_row, n_patches)
        sl = slice(start, stop), 0
        plt.figure(figsize=(8, 4))
        # top row: X, bottom row: Y
        plot_some(X[sl], Y[sl], title_list=[np.arange(start, stop)])
        plt.suptitle(
            '{} example validation patches (top row: source, bottom row: target)'
            .format(stop - start))
        plt.show()
Ejemplo n.º 5
0
    def create_patches(self, normalization=None):
        """Create and save training patches for each channel in ``self.train_channels``.

        For each channel ``ch``:
        - the image pairs under ``<out_dir>/train_data/raw/CH_<ch>`` are read
          with ``RawData.from_folder``, using source dir ``low`` and target
          dir ``GT``;
        - they are turned into patches by the module-level ``create_patches``
          function, which saves them to
          ``<training patch path>/CH_<ch>_training_patches.npz``;
        - six randomly chosen patches (channel 0, drawn with replacement) are
          displayed.

        Parameters
        ----------
        normalization : optional
            Forwarded to ``create_patches`` only when not None, so that
            function's own default applies otherwise.
        """
        for ch in self.train_channels:
            channel_dir = (pathlib.Path(self.out_dir) / "train_data" / "raw" /
                           "CH_{}".format(ch))
            n_images = len(list((channel_dir / "GT").glob("*.tif")))
            print("-- Creating {} patches for channel: {}".format(
                n_images * self.n_patches_per_image, ch))
            raw_data = RawData.from_folder(
                basepath=channel_dir,
                source_dirs=["low"],
                target_dir="GT",
                axes=self.axes,
            )

            patch_kwargs = dict(
                raw_data=raw_data,
                patch_size=self.patch_size,
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.get_training_patch_path() /
                "CH_{}_training_patches.npz".format(ch),
                verbose=False,
            )
            # Pass normalization only when given, keeping create_patches' default otherwise.
            if normalization is not None:
                patch_kwargs["normalization"] = normalization
            # Bare name resolves to the module-level create_patches function, not this method.
            X, Y, _ = create_patches(**patch_kwargs)

            if len(X) == 0:
                # randint(high=0) would raise; nothing to preview.
                print("-- No patches created for channel {}; skipping preview".format(ch))
                continue

            plt.figure(figsize=(16, 4))
            rand_sel = numpy.random.randint(low=0, high=len(X), size=6)
            plot_some(X[rand_sel, 0],
                      Y[rand_sel, 0],
                      title_list=[range(6)],
                      cmap="gray")
            plt.show()

        print("Done")
        return
Ejemplo n.º 6
0
def plot_three(X,
               restored,
               Y,
               ix=None,
               use_ix=True,
               take_restored_mean=True,
               figsize=(16, 10)):
    """Plot input, prediction and target side by side in one row.

    Parameters
    ----------
    X, Y : np.ndarray
        Input and target. With ``use_ix`` they are treated as batches, and
        ``X[ix, ..., 0]`` / ``Y[ix, ..., 0]`` are plotted (sample ``ix``,
        index 0 of the last axis). Otherwise they are plotted as given.
    restored : object
        Prediction. If ``take_restored_mean`` is set, ``restored.mean()`` is
        plotted; otherwise ``restored`` itself is plotted.
    ix : int, optional
        Sample index. Required when ``use_ix`` is True.
    use_ix : bool
        Whether to select sample ``ix`` from X and Y.
    take_restored_mean : bool
        Whether to call ``restored.mean()``.
    figsize : tuple
        Matplotlib figure size.

    Raises
    ------
    TypeError
        If ``use_ix`` is True but ``ix`` is None. Previously this combination,
        which is the default, failed with an unrelated '%d' formatting error.
    """
    if use_ix:
        if ix is None:
            raise TypeError(
                "plot_three: `ix` must be an integer index when use_ix=True")
        x = X[ix, ..., 0]
        y = Y[ix, ..., 0]
        titles = [['input %d' % ix, 'prediction', 'target']]
    else:
        x = X
        y = Y
        titles = [['input', 'prediction', 'target']]
    pred = restored.mean() if take_restored_mean else restored
    # x, pred and y must share a shape: they are stacked into one row
    ims = [[x, pred, y]]
    plt.figure(figsize=figsize)
    plot_some(np.stack(ims), title_list=titles)
Ejemplo n.º 7
0
# Model name; presumably used further down when the CARE model is created.
ModelName = 'WingVeinUNET'
# Path to the .npz training-patch file. BaseDir and NPZdata are defined
# elsewhere and joined by plain string concatenation, so BaseDir presumably
# ends with a path separator (TODO confirm).
load_path = BaseDir + NPZdata

# In[3]:

# Load the patches, holding out 5% of them for validation.
(X, Y), (X_val, Y_val), axes = load_training_data(load_path,
                                                  validation_split=0.05,
                                                  verbose=True)

# Number of input/output channels, read from the 'C' axis.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

# In[4]:

# Preview the first five validation pairs.
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')

# In[5]:

config = config = Config(axes,
                         n_channel_in,
                         n_channel_out,
                         probabilistic=False,
                         unet_n_depth=5,
                         unet_n_first=48,
                         unet_kern_size=7,
                         train_loss='mae',
                         train_epochs=150,
                         train_learning_rate=1.0E-4,
Ejemplo n.º 8
0
    #   "--output_filename",
    #   help="the name of the out"
    # )
    # Plotting is on by default. "-p/--plot" is kept for backward
    # compatibility, but store_true with default=True can never turn plotting
    # off, so "--no-plot" (same dest) is provided to disable it.
    parser.add_argument("-p", "--plot", action="store_true", default=True)
    parser.add_argument("--no-plot",
                        dest="plot",
                        action="store_false",
                        help="do not plot the prediction next to the test pair")
    args = parser.parse_args()

    # Load the trained model
    model = CARE(config=None, name=args.model_name, basedir=args.model_dir)

    # Read the test images
    x = imread(
        os.path.join(args.base_dir, args.path_to_test_image, "low_snr.tif"))
    y = imread(
        os.path.join(args.base_dir, args.path_to_test_image, "high_snr.tif"))

    # Predict the 2D ("YX") image; n_tiles is not passed (see trailing comment).
    restored = model.predict(x, axes="YX")  # , n_tiles=(1,4,4))

    # Save the restored image
    save_tiff_imagej_compatible(
        os.path.join(args.base_dir, args.path_to_test_image, "predicted.tif"),
        restored,
        axes="YX")

    # Plot the restored image next to the test pair
    if args.plot:
        plt.figure(figsize=(16, 10))
        plot_some(
            np.stack([x, restored, y]),
            title_list=[['low res', 'CARE', 'target']])
        plt.show()
Ejemplo n.º 9
0
# Name of the trained CARE model to load from the 'models' directory.
modelName = "modelRulerSpectralXCentered"  # model trained without noise
#modelName = "modelBeadsLittleNoiseSpectra64_200pepoch_woBorder" # model trained with added noise

########################
# predict output image #
########################
# No new config is passed (config=None), presumably so CARE reloads the
# configuration saved for `modelName` (TODO confirm against csbdeep docs).
model = CARE(config=None, name=modelName, basedir='models')
# Predict a single image with axes "YX". `x` and `imageNumber` are defined
# earlier in the script. normalizer=None presumably disables input
# normalization (confirm).
restored = model.predict(x[imageNumber], "YX", normalizer=None)

###############
# Show result #
###############
# Source and prediction side by side. pmin/pmax are presumably the display
# percentile limits used by plot_some.
plt.figure(figsize=(16, 10))
plot_some(np.stack([x[imageNumber], restored]),
          title_list=[['source image', 'predicted (CARE)']],
          pmin=2,
          pmax=99.8)

plt.show()

#########################
# Predict probabilistic #
#########################
# Probabilistic prediction of the same image. The result exposes mean() and
# scale(), which are plotted side by side below.
# NOTE(review): the original author was unsure whether "YX" is the right
# axes argument here (the trailing "axes?") — confirm against the input.
restored_prob = model.predict_probabilistic(x[imageNumber],
                                            "YX",
                                            normalizer=None)  #axes?
plt.figure(figsize=(16, 10))
plot_some(np.stack([restored_prob.mean(),
                    restored_prob.scale()]),
          title_list=[['mean', 'scale']])
plt.show()
                                   channel_low=args.channel_low)
    print(f"{number_of_images} image pairs found")

    # Take a random one of the image pairs, hide as a final test image,
    # then optionally plot
    # NOTE(review): np.random.randint excludes its upper bound, so this draws
    # from the number_of_images + 1 values 0..number_of_images. With
    # number_of_images files, one of those values has no file: 0 if files
    # are numbered from 1, number_of_images if numbered from 0. Confirm the
    # numbering scheme and adjust the bounds.
    img_number = np.random.randint(0, number_of_images + 1)

    if args.plot:
        x = imread(
            os.path.join(args.base_dir, "low_snr", f"img_{img_number}.tif"))
        y = imread(
            os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"))
        print(f"image size: {x.shape}")
        plt.figure(figsize=(16, 10))
        plot_some(
            np.stack([x, y]),
            title_list=[['low snr', 'high snr']],
        )
        plt.show()

    # Move the chosen pair into <base_dir>/test as low_snr.tif / high_snr.tif.
    # It is then no longer in the low_snr/high_snr folders read below.
    if not os.path.exists(os.path.join(args.base_dir, "test")):
        os.mkdir(os.path.join(args.base_dir, "test"))
    shutil.move(
        os.path.join(args.base_dir, "low_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "low_snr.tif"))
    shutil.move(
        os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "high_snr.tif"))

    # Read the pairs, passing in the axis semantics
    raw_data = RawData.from_folder(basepath=args.base_dir,
                                   source_dirs=['low_snr'],
Ejemplo n.º 11
0
# Subset the data to the channels named on the command line (input_channels),
# using the channel names saved earlier in chan_names.npy.
print('using previously saved channel names for subsetting')
chans = np.load(base_dir + 'chan_names.npy')

keepers = input_channels
keep_idx = np.isin(chans, keepers)  # boolean mask over chans
if not keep_idx.any():
    raise ValueError("Did not supply valid channel name")

print('analyzing the following channels: {}'.format(chans[keep_idx]))


# Keep only the selected channels along the last axis.
x_train, x_test = x_train[:, :, :, keep_idx], x_test[:, :, :, keep_idx]
y_train, y_test = y_train[:, :, :, keep_idx], y_test[:, :, :, keep_idx]


# this code taken directly from FAQ, uses internal functions to do plotting
# The network output holds the predicted Laplace means in the first half of
# the last axis and the scales in the second half.
fig = plt.figure(figsize=(30, 30))
_P = model.keras_model.predict(x_test[:5, :, :, :])
n_half = _P.shape[-1] // 2
_P_mean = _P[..., :n_half]
_P_scale = _P[..., n_half:]
# NOTE(review): only the first selected input channel of x_test is shown.
plot_some(x_test[:5, :, :, 0], y_test[:5, :, :, :], _P_mean, _P_scale, pmax=99.5)
fig.suptitle('5 example validation patches\n'
             'first row: input (source),  '
             'second row: target (ground truth),  '
             'third row: predicted Laplace mean,  '
             'fourth row: predicted Laplace scale')
# NOTE(review): '/models/' is an absolute path at the filesystem root.
# Confirm this is intended (e.g. a mounted volume) rather than 'models/'.
fig.savefig('/models/' + model_name + '.pdf')

Ejemplo n.º 12
0
#################
# Load the training patches, holding out a fraction `validationSplit` of them
# for validation. (An earlier comment said 10%; check validationSplit.)
(X_train,
 Y_train), (X_val,
            Y_val), axes = load_training_data(folderName + '/' + filename,
                                              validation_split=validationSplit,
                                              verbose=True)
# Alternative input used during development (synthetic disks, 10% validation):
#(X_train, Y_train), (X_val,Y_val), axes = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)

print("axes : ", axes)

# Number of input/output channels, read from the 'C' axis.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X_train.shape[c], Y_train.shape[c]

# Preview the first five validation pairs.
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')

#################
# Configuration #
#################

# Config object contains: parameters of the underlying neural network, learning rate, number of parameter updates per epoch, loss function, and whether the model is probabilistic or not.
# Here a probabilistic model is requested, with stepPerEpoch (defined
# elsewhere) training steps per epoch.

config = Config(axes,
                n_channel_in,
                n_channel_out,
                probabilistic=True,
                train_steps_per_epoch=stepPerEpoch)
print(config)
Ejemplo n.º 13
0
def plot_four(x, pred, y, yn, figsize=(16, 10)):
    """Show input, prediction, normalized GT and raw GT in a single row.

    Parameters
    ----------
    x : array
        Input image.
    pred : array
        Model prediction.
    y : array
        Ground truth, titled 'GT'.
    yn : array
        Normalized ground truth, titled 'N(GT)'.
    figsize : tuple
        Matplotlib figure size.

    All four images must share a shape, because they are combined with
    ``np.stack``.
    """
    row_titles = ['input', 'prediction', 'N(GT)', 'GT']
    row_images = [x, pred, yn, y]
    plt.figure(figsize=figsize)
    plot_some(np.stack([row_images]), title_list=[row_titles])
Ejemplo n.º 14
0
    def train(self, channels=None, **config_args):
        """Train, preview and export one CARE model per channel.

        For every channel ``ch``, inside a ``Timer("Training")`` context, the
        method:
        - loads ``CH_<ch>_training_patches.npz`` from the training-patch path,
          holding out 10% of it for validation;
        - builds a ``Config``, probabilistic if ``self.probabilistic`` is set;
        - trains ``CH_<ch>_model`` under ``<out_dir>/models``;
        - plots the loss history and five validation predictions;
        - exports the model with ``export_TF``.

        Parameters
        ----------
        channels : iterable, optional
            Channels to train. Defaults to ``self.train_channels``.
        **config_args
            Extra keyword arguments forwarded to ``Config``.

        If training raises ``tf.errors.ResourceExhaustedError`` or
        ``tf.errors.UnknownError``, a message is printed and the method
        returns without handling the remaining channels.
        """
        # limit_gpu_memory(fraction=1)
        if channels is None:
            channels = self.train_channels

        with Timer("Training"):

            for ch in channels:
                print("-- Training channel {}...".format(ch))
                (X, Y), (X_val, Y_val), axes = load_training_data(
                    self.get_training_patch_path() /
                    "CH_{}_training_patches.npz".format(ch),
                    validation_split=0.1,
                    verbose=False,
                )

                # Channel counts are read from the 'C' axis
                c = axes_dict(axes)["C"]
                n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

                config = Config(
                    axes,
                    n_channel_in,
                    n_channel_out,
                    train_epochs=self.train_epochs,
                    train_steps_per_epoch=self.train_steps_per_epoch,
                    train_batch_size=self.train_batch_size,
                    probabilistic=self.probabilistic,
                    **config_args,
                )

                # TODO: consider passing config=None to reuse the saved
                # config when <out_dir>/models/CH_<ch>_model already exists.
                model = CARE(
                    config,
                    "CH_{}_model".format(ch),
                    basedir=pathlib.Path(self.out_dir) / "models",
                )

                try:
                    history = model.train(X, Y, validation_data=(X_val, Y_val))
                except tf.errors.ResourceExhaustedError:
                    print(
                        " >> ResourceExhaustedError: Aborting...\n  Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                    )
                    return
                except tf.errors.UnknownError:
                    print(
                        " >> UnknownError: Aborting...\n  Not enough memory available on GPU... are other GPU jobs running?"
                    )
                    return

                # Learning curve
                plt.figure(figsize=(16, 5))
                plot_history(history, ["loss", "val_loss"],
                             ["mse", "val_mse", "mae", "val_mae"])

                # Five validation examples: source / target / prediction
                plt.figure(figsize=(12, 7))
                _P = model.keras_model.predict(X_val[:5])

                if self.probabilistic:
                    # Assumed output layout (as in the CARE FAQ mean/scale
                    # split): the means of all n_channel_out channels come
                    # first, then the scales. Keep every channel's mean;
                    # previously only channel 0 was kept.
                    _P = _P[..., :n_channel_out]

                plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
                plt.suptitle("5 example validation patches\n"
                             "top row: input (source),  "
                             "middle row: target (ground truth),  "
                             "bottom row: predicted from source")

                plt.show()

                print("-- Export model for use in Fiji...")
                model.export_TF()
                print("Done")