Beispiel #1
0
def run_z(y, x, models, *, n_tiles=(1, 4, 4), basedir='models'):
    """Compare CARE models against ground truth on a ZYX stack and plot it.

    The input ``x`` and ground truth ``y`` are converted to float64 and
    percentile-normalized over ZYX so that RMSE/SSIM values are comparable.
    The normalized input is scored first as a baseline. Then, for each model
    name in ``models``, a CARE model is loaded from ``basedir`` and run
    probabilistically on the input. The mean of its prediction is normalized
    the same way and scored against the normalized ground truth.

    Two figures are shown:
    one index of every image (input, normalized GT, each prediction), and a
    bar plot of RMSE and SSIM per output.

    Args:
        y: Ground-truth stack (array-like), converted to float64.
        x: Input stack (array-like, axes ZYX), converted to float64.
        models: Iterable of CARE model names to evaluate.
        n_tiles: Tiling passed to ``predict_probabilistic``. Defaults to the
            previously hard-coded ``(1, 4, 4)``.
        basedir: Directory the CARE models are loaded from. Defaults to the
            previously hard-coded ``'models'``.

    Returns:
        None. Results are only displayed.
    """
    # Model prediction and normalizations output float64, so work in float64.
    x = np.array(x, dtype='float64')
    y = np.array(y, dtype='float64')
    axes = 'ZYX'

    # Rows of [output name, rmse, ssim], plus the layout used to melt them
    # into long form for the bar plot.
    d = {
        'output': [],
        'columns': ['output', 'rmse', 'ssim'],
        'id_vars': ['output'],
        'var_name': 'metric'
    }

    def get_output(name, gt, pred):
        """Return the metric row [name, RMSE, SSIM] of ``pred`` vs ``gt``."""
        return [name, np.sqrt(mse(pred, gt)), ssim(pred, gt)]

    # Normalize GT dynamic range to enable comparisons w/ numbers.
    yn = util.percentile_norm(y, axes)

    # Baseline: normalized input compared against normalized GT.
    xn = util.percentile_norm(x, axes)
    d['output'].append(get_output('input', yn, xn))

    predictions = {'models': ['input', 'N(GT)'], 'ims': [xn, yn]}
    for m in models:
        model = CARE(config=None, name=m, basedir=basedir)
        restored = model.predict_probabilistic(x, axes, n_tiles=n_tiles)
        # Use the mean of the predicted distribution as the point estimate.
        pred = restored.mean()
        y_pred_n = util.percentile_norm(pred, axes)

        d['output'].append(get_output(m, yn, y_pred_n))
        predictions['models'].append(m)
        predictions['ims'].append(y_pred_n)

    # Plot the same randomly chosen index along the first (Z) axis of every
    # image.
    # NOTE(review): assumes util.get_randint_ixs(1, n) returns one random
    # index below n — confirm.
    zix = util.get_randint_ixs(1, len(y))
    ims = [[im[zix] for im in predictions['ims']]]
    plt.figure(figsize=(16, 10))
    plot_some(np.stack(ims), title_list=[predictions['models']])
    plt.show()

    # Construct a long-form df (output, metric, value) for the barplot.
    df = pd.DataFrame(
        d['output'],
        columns=d['columns'],
    )
    df = pd.melt(df, id_vars=d['id_vars'], var_name=d['var_name'])

    g = sns.catplot(x='metric',
                    y='value',
                    hue='output',
                    kind='bar',
                    sharey=False,
                    data=df)
    # NOTE(review): the fixed y-range assumes both metrics of normalized data
    # lie in [0, 1]; bars outside it are clipped.
    g.ax.set_ylim(0, 1)
    plt.show()
Beispiel #2
0
# Script continuation: `model`, `x`, `imageNumber` and `restored` are used
# here but defined earlier in the script (not shown in this snippet).

###############
# Show result #
###############
# Show the source image and the CARE prediction side by side. pmin/pmax are
# passed to plot_some (presumably percentile-based display normalization).
plt.figure(figsize=(16, 10))
plot_some(np.stack([x[imageNumber], restored]),
          title_list=[['source image', 'predicted (CARE)']],
          pmin=2,
          pmax=99.8)

plt.show()

#########################
# Predict probabilistic #
#########################
# NOTE(review): "YX" assumes x[imageNumber] is a single 2D image, and
# normalizer=None presumably disables CSBDeep's input normalization — confirm
# both match how `restored` above was produced.
restored_prob = model.predict_probabilistic(x[imageNumber],
                                            "YX",
                                            normalizer=None)
# Plot the mean and the scale of the predicted per-pixel distribution.
plt.figure(figsize=(16, 10))
plot_some(np.stack([restored_prob.mean(),
                    restored_prob.scale()]),
          title_list=[['mean', 'scale']])
plt.show()

######################
# Save the two files #
######################
# Write the CARE prediction and the source image (as indexed from x) to TIFF.
imsave("restoredImage.tif", restored)
imsave("notRestoredImage.tif", x[imageNumber])

# CSBDeep also provides a save_tiff_imagej_compatible function that could be
# used instead (TODO: study the difference). Images sometimes appear grey in
# ImageJ; that may be the reason.
Beispiel #3
0
    def predict(self, file_fn, n_tiles=(1, 4, 4), keep_meta=True):
        """Predict every trained channel of an image file with CARE and save it.

        The file is read plane by plane through a Bio-Formats reader. For each
        channel in ``self.train_channels``:

        - The CARE model ``CH_<ch>_model`` is loaded from
          ``<self.out_dir>/models``.
        - Every time point is assembled into a ZYX volume, rescaled by
          ``self.low_scaling`` and predicted.
        - The result is written to ``<input stem>_care_predict_ch<ch>.tif``
          next to the input file, with axes TZCYX. The C axis holds the
          prediction, or the predictive mean and scale when
          ``self.probabilistic`` is set.

        Args:
            file_fn: Path of the image file to predict.
            n_tiles: Tiling passed to CARE's ``predict`` /
                ``predict_probabilistic``.
            keep_meta: If true, write an ImageJ TIFF carrying the rescaled
                pixel size, z-spacing, unit and frame interval. Otherwise
                write a plain TIFF.

        Only Bio-Formats pixel types 1 (uint8) and 3 (uint16) are supported.
        Any other type prints an error and returns without writing output.
        """
        JVM().start()

        pixel_reso = get_space_time_resolution(file_fn)
        print("Prediction {}".format(file_fn))
        print(" -- Using pixel sizes and frame interval", pixel_reso)

        ir = bf.ImageReader(file_fn)
        # Always release the Bio-Formats reader, including on the early return
        # for unsupported pixel types and on prediction errors. Previously it
        # was never closed.
        try:
            reader = ir.rdr

            loci_pixel_type = reader.getPixelType()

            if loci_pixel_type == 1:
                # uint8
                dtype = numpy.uint8
            elif loci_pixel_type == 3:
                # uint16
                dtype = numpy.uint16
            else:
                print(
                    "Error: Pixel-type not supported. Pixel type must be 8- or 16-bit"
                )
                return

            if self.probabilistic:
                # Mean/scale outputs are floating point: read and store float32.
                dtype = numpy.float32

            series = 0
            z_size = reader.getSizeZ()
            y_size = reader.getSizeY()
            x_size = reader.getSizeX()
            c_size = reader.getSizeC()
            t_size = reader.getSizeT()

            # skimage.transform.rescale sizes its output as
            # max(round(scale * shape), 1). Use the same rule here.
            # The former int() truncation allocated a buffer one voxel short
            # whenever a scaled size was not integral (e.g. 10 * 0.75 -> 7 vs
            # 8), which broke the assignment of each prediction below.
            out_shape = numpy.maximum(
                numpy.round(
                    numpy.atleast_1d(self.low_scaling)
                    * numpy.asarray((z_size, y_size, x_size))
                ),
                1,
            ).astype(int)
            z_out_size, y_out_size, x_out_size = (int(s) for s in out_shape)

            if c_size != len(self.train_channels):
                print(
                    " -- Warning: Number of Channels during training and prediction do not match. Using channels {} for prediction"
                    .format(self.train_channels))

            # The C axis of the output holds [prediction] or [mean, scale].
            out_channels = 2 if self.probabilistic else 1

            for ch in self.train_channels:
                model = CARE(
                    None,
                    "CH_{}_model".format(ch),
                    basedir=pathlib.Path(self.out_dir) / "models",
                )

                res_image_ch = numpy.zeros(
                    shape=(t_size, z_out_size, out_channels, y_out_size,
                           x_out_size),
                    dtype=dtype,
                )

                print(" -- Predicting channel {}".format(ch))
                for t in tqdm(range(t_size), total=t_size):
                    # Assemble the ZYX volume of channel `ch` at time point t.
                    img_3d = numpy.zeros((z_size, y_size, x_size), dtype=dtype)
                    for z in range(z_size):
                        img_3d[z, :, :] = ir.read(series=series,
                                                  z=z,
                                                  c=ch,
                                                  t=t,
                                                  rescale=False)

                    # NOTE(review): `multichannel` is deprecated/removed in
                    # newer scikit-image (replaced by channel_axis=None) —
                    # confirm against the installed version.
                    img_3d_ch_ex = rescale(
                        img_3d,
                        self.low_scaling,
                        preserve_range=True,
                        order=self.order,
                        multichannel=False,
                        mode="reflect",
                        anti_aliasing=True,
                    )

                    if not self.probabilistic:
                        pred = model.predict(img_3d_ch_ex,
                                             axes="ZYX",
                                             n_tiles=n_tiles)
                        # Clip to the integer range before casting back, so
                        # out-of-range values saturate instead of wrapping.
                        di = numpy.iinfo(dtype)
                        pred = pred.clip(di.min, di.max).astype(dtype)
                        res_image_ch[t, :, 0, :, :] = pred
                    else:
                        pred = model.predict_probabilistic(img_3d_ch_ex,
                                                           axes="ZYX",
                                                           n_tiles=n_tiles)
                        res_image_ch[t, :, 0, :, :] = pred.mean()
                        res_image_ch[t, :, 1, :, :] = pred.scale()

                ch_out_fn = os.path.join(
                    os.path.dirname(file_fn),
                    os.path.splitext(os.path.basename(file_fn))[0] +
                    "_care_predict_ch{}.tif".format(ch),
                )
                print(" -- Saving channel {} CARE prediction to file '{}'".format(
                    ch, ch_out_fn))

                if keep_meta:
                    # Rescaling divides the pixel size by low_scaling. TIFF
                    # resolution is pixels per unit, hence the reciprocal.
                    reso = (
                        1 / (pixel_reso.X / self.low_scaling[2]),
                        1 / (pixel_reso.Y / self.low_scaling[1]),
                    )
                    spacing = pixel_reso.Z / self.low_scaling[0]
                    unit = pixel_reso.Xunit
                    finterval = pixel_reso.T

                    tifffile.imsave(
                        ch_out_fn,
                        res_image_ch,
                        imagej=True,
                        resolution=reso,
                        metadata={
                            "axes": "TZCYX",
                            "finterval": finterval,
                            "spacing": spacing,
                            "unit": unit,
                        },
                    )
                else:
                    tifffile.imsave(ch_out_fn, res_image_ch)

                # Drop the reference so the per-channel buffer can be freed
                # before the next channel is allocated.
                res_image_ch = None
        finally:
            ir.close()