Esempio n. 1
0
    def test_load_and_save(self):
        """Round-trip an integer image through `images.save_image` and
        `images.load_image` and check that the pixel values are unchanged."""

        original_img = np.random.randint(128, size=(4, 6))

        # The file is written in a scratch directory that is removed when the
        # "with" block exits
        with tempfile.TemporaryDirectory() as scratch_dir:
            fits_path = os.path.join(scratch_dir, "test.fits")

            # Write the image, then read it back from the same path
            images.save_image(original_img, fits_path)
            reloaded_img = images.load_image(fits_path)

            # The reloaded pixels must match the ones that were written
            np.testing.assert_array_equal(original_img, reloaded_img)
Esempio n. 2
0
    def test_load_and_save_with_nan(self):
        """Round-trip a float image containing one NaN pixel through
        `images.save_image` and `images.load_image` and check that both the
        regular pixels and the NaN pixel are preserved."""

        original_img = np.random.uniform(size=(4, 6))
        original_img[1, 1] = np.nan

        # The file is written in a scratch directory that is removed when the
        # "with" block exits
        with tempfile.TemporaryDirectory() as scratch_dir:
            fits_path = os.path.join(scratch_dir, "test.fits")

            # Write the image, then read it back from the same path
            images.save_image(original_img, fits_path)
            reloaded_img = images.load_image(fits_path)

            # The reloaded pixels must match the ones that were written
            np.testing.assert_array_equal(original_img, reloaded_img)

            # The pixel set to NaN before saving must still be NaN
            nan_pixel = reloaded_img[1, 1]
            self.assertTrue(np.isnan(nan_pixel))
Esempio n. 3
0
    def test_load_wrong_hdu_error(self):
        """Check the call to `images.load_image` fails with a WrongHDUError
        when trying to load a FITS image from an HDU index that does not exist.

        The same file is then loaded with ``hdu_index=0``, which must not raise
        WrongHDUError."""

        img = np.random.randint(128, size=(3, 3))  # Make a 2D image

        # Make a temporary directory to store fits files
        with tempfile.TemporaryDirectory() as temp_dir_path:

            img_path = os.path.join(temp_dir_path, "test.fits")

            # Save the image
            images.save_image(img, img_path)

            # Load the saved image (should raise an exception); the return
            # value is irrelevant here so it is not bound to a name
            with self.assertRaises(WrongHDUError):
                images.load_image(img_path, hdu_index=1000)

            # Load the saved image (should not raise any exception); an
            # unexpected WrongHDUError is reported as a test failure with an
            # explicit message
            try:
                images.load_image(img_path, hdu_index=0)
            except WrongHDUError:
                self.fail(
                    "images.load_image() raised WrongHDUError unexpectedly!")
Esempio n. 4
0
    def test_save_wrong_dimension_error(self):
        """Check the call to `images.save_image` fails with a
        WrongDimensionError when saved images have more than 3 dimensions or
        less than 2 dimensions.

        2D and 3D images are also saved to check that they are accepted."""

        img_1d = np.random.randint(128, size=(3,))  # Make a 1D image
        img_2d = np.random.randint(128, size=(3, 3))  # Make a 2D image
        img_3d = np.random.randint(128, size=(3, 3, 3))  # Make a 3D image
        img_4d = np.random.randint(128, size=(3, 3, 3, 3))  # Make a 4D image

        # Make a temporary directory to store fits files
        with tempfile.TemporaryDirectory() as temp_dir_path:

            # All the images are written to the same path, in this order:
            # 1D, 2D, 3D, 4D
            img_path = os.path.join(temp_dir_path, "test.fits")

            # Save the 1D image (should raise an exception)
            with self.assertRaises(WrongDimensionError):
                images.save_image(img_1d, img_path)

            # Save the 2D image, then the 3D image (should not raise any
            # exception); an unexpected WrongDimensionError is reported as a
            # test failure with an explicit message
            for valid_img in (img_2d, img_3d):
                try:
                    images.save_image(valid_img, img_path)
                except WrongDimensionError:
                    self.fail(
                        "images.save_image() raised WrongDimensionError unexpectedly!")

            # Save the 4D image (should raise an exception)
            with self.assertRaises(WrongDimensionError):
                images.save_image(img_4d, img_path)
Esempio n. 5
0
def main():
    """The main module execution function.

    Contains the instructions executed when the module is not imported but
    directly called from the system command line.

    The function parses and checks the command line options, loads the first
    file given in ``args.fileargs``, filters a float64 copy of it with
    `clean_image`, optionally plots and/or saves a plot of the input and
    filtered images, and finally saves the filtered image to
    ``<input basename>-out<input extension>``.

    Raises
    ------
    ValueError
        If the type of filtering, the filter thresholds or the last scale
        treatment given on the command line are not valid.
    """

    # PARSE OPTIONS ###########################################################

    parser = argparse.ArgumentParser(
        description="Filter images with Starlet Transform.")

    parser = add_arguments(parser)
    parser = add_common_arguments(parser)

    args = parser.parse_args()

    type_of_filtering = args.type_of_filtering
    filter_thresholds_str = args.filter_thresholds
    last_scale_treatment = args.last_scale
    detect_only_positive_structures = args.detect_only_positive_structures
    remove_isolated_pixels = args.remove_isolated_pixels

    plot = args.plot
    saveplot = args.saveplot

    # Only the first file argument is processed
    input_file = args.fileargs[0]

    # CHECK OPTIONS #############################

    if type_of_filtering not in hard_filter.AVAILABLE_TYPE_OF_FILTERING:
        raise ValueError(
            'Unknown type of filterning: "{}". Should be in {}'.format(
                type_of_filtering, hard_filter.AVAILABLE_TYPE_OF_FILTERING))

    try:
        filter_thresholds = [
            float(threshold_str)
            for threshold_str in filter_thresholds_str.split(",")
        ]
    except (AttributeError, TypeError, ValueError) as err:
        # ValueError: one of the comma separated items is not a number.
        # AttributeError / TypeError: the option value cannot be split as a
        # string (e.g. it is None).
        # Only these errors are converted, so that e.g. KeyboardInterrupt is
        # not masked as a "wrong thresholds" error.
        raise ValueError(
            'Wrong filter thresholds: "{}". Should be in a list of figures separated by a comma (e.g. "3,2,3")'
            .format(filter_thresholds_str)) from err

    if last_scale_treatment not in mrtransform_wrapper.AVAILABLE_LAST_SCALE_OPTIONS:
        raise ValueError(
            'Unknown type of last scale treatment: "{}". Should be in {}'.
            format(last_scale_treatment,
                   mrtransform_wrapper.AVAILABLE_LAST_SCALE_OPTIONS))

    # TODO: `args.noise_cdf_file` is currently ignored. Check its value and
    # build the noise distribution from it (the previous draft used
    # `EmpiricalDistribution(noise_cdf_file)`); until then no noise
    # distribution is given to `clean_image`.
    noise_distribution = None

    # CLEAN THE INPUT IMAGE ###################################

    input_img = load_image(input_file)

    # Copy the image (otherwise some cleaning functions may change it)
    input_img_copy = input_img.astype('float64', copy=True)

    cleaned_img = clean_image(
        input_img_copy,
        type_of_filtering=type_of_filtering,
        filter_thresholds=filter_thresholds,
        last_scale_treatment=last_scale_treatment,
        detect_only_positive_structures=detect_only_positive_structures,
        kill_isolated_pixels=remove_isolated_pixels,
        noise_distribution=noise_distribution)

    # PLOT IMAGES #########################################################

    if plot or (saveplot is not None):
        image_list = [input_img, cleaned_img]
        title_list = ["Input image", "Filtered image"]

        if plot:
            plot_list(image_list, title_list=title_list)

        if saveplot is not None:
            plot_file_path = saveplot
            print("Saving {}".format(plot_file_path))
            mpl_save_list(image_list,
                          output_file_path=plot_file_path,
                          title_list=title_list)

    # SAVE IMAGE ##########################################################

    # The output file is written next to the input file: "-out" is inserted
    # between the input path stem and its extension
    basename, extension = os.path.splitext(input_file)
    output_file_path = "{}-out{}".format(basename, extension)
    save_image(cleaned_img, output_file_path)