예제 #1
0
    def test_get_all_attributes(self):
        """The 'images' dataset reports its static and non-static attribute names."""
        images_port = InputPort('images', self.storage)

        expected_static = {'PIXSCALE': 0.01}
        expected_non_static = ['PARANG']

        assert images_port.get_all_static_attributes() == expected_static
        assert images_port.get_all_non_static_attributes() == expected_non_static
예제 #2
0
    def add_input_port(self, tag):
        """
        Create an InputPort for this ProcessingModule and register it in the internal
        InputPort dictionary under its tag. Classes inheriting from ProcessingModule should
        use this function to add input ports. Ports are keyed by tag, so registering a tag
        that is already present replaces the previous port. The new port is available as: ::

             port = self._m_input_ports[tag]

        and is also returned directly.

        :param tag: Tag of the new input port.
        :type tag: str

        :return: The new InputPort for the ProcessingModule.
        :rtype: InputPort
        """

        new_port = InputPort(tag)

        # only connect the port when the module already knows its database
        database = self._m_data_base
        if database is not None:
            new_port.set_database_connection(database)

        self._m_input_ports[tag] = new_port
        return new_port
예제 #3
0
    def test_create_instance_access_data(self):
        """Reading the 'images' dataset through an InputPort returns the stored values."""
        images = InputPort("images", self.storage)

        def close_to(actual, expected):
            # relative comparison with the absolute tolerance set to zero
            return np.allclose(actual, expected, rtol=limit, atol=0.)

        first_pixels = np.asarray((0.00032486907273264834,
                                   -2.4494781298462809e-05,
                                   -0.00038631277795631806),
                                  dtype=np.float64)

        assert close_to(images[0, 0, 0], 0.00032486907273264834)
        assert close_to(np.mean(images.get_all()), 1.0506056979365338e-06)
        assert close_to(images[0:3, 0, 0], first_pixels)

        assert len(images[0:2, 0, 0]) == 2
        assert images.get_shape() == (10, 100, 100)

        assert images.get_attribute("PIXSCALE") == 0.01
        assert images.get_attribute("PARANG")[0] == 1

        # an unknown attribute name yields None together with a UserWarning
        with pytest.warns(UserWarning):
            assert images.get_attribute("none") is None
예제 #4
0
    def test_create_instance(self):
        """
        Data written through an OutputPort created with ``activate_init=False`` is expected
        not to be stored. Data written through an activated OutputPort becomes readable by
        an InputPort on the same tag.
        """
        active_port = OutputPort("test", self.storage, activate_init=True)
        deactive_port = OutputPort("test", self.storage, activate_init=False)
        control_port = InputPort("test", self.storage)

        try:
            deactive_port.open_port()
            deactive_port.set_all(np.asarray([0, 1, 2, 3]))
            deactive_port.flush()

            # Nothing was stored under the tag, so reading it must warn and return None.
            # Checked explicitly, as in the other InputPort tests, instead of only noting it
            # in a comment.
            with pytest.warns(UserWarning):
                assert control_port.get_all() is None

            active_port.set_all(np.asarray([0, 1, 2, 3]))
            active_port.flush()

            assert np.array_equal(np.asarray([0, 1, 2, 3]), control_port.get_all())

        finally:
            # remove the test data even when one of the assertions above failed
            active_port.del_all_data()
예제 #5
0
    def __init__(self, working_pypeline):
        """
        Create the database tags and connected InputPorts of one ImageWrapper instance.

        Every tag is a fixed root plus the zero-padded ``ImageWrapper.class_counter``, so each
        wrapper created on the same database gets its own entries. The counter is incremented
        once all ports exist.

        :param working_pypeline: Pypeline whose ``m_data_storage`` the ports are connected to.
        """

        super(ImageWrapper, self).__init__(working_pypeline)

        def connected_port(tag):
            # InputPort for the tag, attached to the pipeline database
            new_port = InputPort(tag)
            new_port.set_database_connection(working_pypeline.m_data_storage)
            return new_port

        # Tag roots without identification number; used for data export.
        self._m_tag_root_image = "im_arr"
        self._m_tag_root_mask_image = "im_mask_arr"
        self._m_tag_root_mask = "im_cent_mask"
        self._m_tag_root_psf_image_arr = "psf_im_arr"

        # The old PynPoint allowed several image instances working on separate in-memory data,
        # so each ImageWrapper needs its own database entries (increasing identification number).
        number = str(ImageWrapper.class_counter).zfill(2)

        self._m_image_data_tag = self._m_tag_root_image + number
        self._m_image_data_port = connected_port(self._m_image_data_tag)

        self._m_image_data_masked_tag = self._m_tag_root_mask_image + number
        self._m_image_data_masked_port = connected_port(self._m_image_data_masked_tag)

        self._m_mask_tag = self._m_tag_root_mask + number
        self._m_mask_port = connected_port(self._m_mask_tag)

        self._m_psf_image_arr_tag = self._m_tag_root_psf_image_arr + number
        self._m_psf_image_arr_port = connected_port(self._m_psf_image_arr_tag)

        # root tag -> numbered tag (restore) and its inverse (save)
        self._m_restore_tag_dict = {
            self._m_tag_root_image: self._m_image_data_tag,
            self._m_tag_root_mask_image: self._m_image_data_masked_tag,
            self._m_tag_root_mask: self._m_mask_tag,
            self._m_tag_root_psf_image_arr: self._m_psf_image_arr_tag
        }
        self._m_save_tag_dict = dict(
            (numbered, root) for root, numbered in self._m_restore_tag_dict.items())

        ImageWrapper.class_counter += 1
예제 #6
0
    def test_create_instance_access_non_existing_data(self):
        """Every accessor of a port on a tag without data warns and returns None."""
        missing = InputPort("test", self.storage)

        accessors = (
            lambda: missing[0, 0, 0],
            missing.get_all,
            missing.get_shape,
            lambda: missing.get_attribute("num_files"),
            missing.get_all_non_static_attributes,
            missing.get_all_static_attributes,
        )

        for accessor in accessors:
            with pytest.warns(UserWarning):
                assert accessor() is None
예제 #7
0
    def test_create_instance_no_data_storage(self):
        """Every accessor of a port created without a database warns and returns None."""
        unconnected = InputPort("test")

        accessors = (
            lambda: unconnected[0, 0, 0],
            unconnected.get_all,
            unconnected.get_shape,
            unconnected.get_all_non_static_attributes,
            unconnected.get_all_static_attributes,
        )

        for accessor in accessors:
            with pytest.warns(UserWarning):
                assert accessor() is None
예제 #8
0
class ResidualsWrapper(object):
    """
    Combines an image wrapper and a basis wrapper and exposes the residuals computed by
    ``CreateResidualsModule`` through InputPorts on the pipeline database.
    """

    # Zero-padded into the database tags so that every instance uses its own tags.
    class_counter = 1

    def __init__(self, working_pypeline):
        """
        :param working_pypeline: Pypeline whose ``m_data_storage`` backs the result ports and
                                 which runs the residual module.
        """
        self._pypeline = working_pypeline

        self._m_res_arr_root = "res_arr"
        self._m_res_rot_root = "res_rot"
        self._m_res_mean_root = "res_mean"
        self._m_res_median_root = "res_median"
        self._m_res_var_root = "res_var"
        self._m_res_rot_mean_clip_root = "res_rot_mean_clip"

        number = str(ResidualsWrapper.class_counter).zfill(2)

        def connected_port(tag):
            # InputPort for the tag, attached to the pipeline database
            port = InputPort(tag)
            port.set_database_connection(self._pypeline.m_data_storage)
            return port

        self._m_res_arr = self._m_res_arr_root + number
        self._m_res_arr_port = connected_port(self._m_res_arr)

        self._m_res_rot = self._m_res_rot_root + number
        self._m_res_rot_port = connected_port(self._m_res_rot)

        self._m_res_mean = self._m_res_mean_root + number
        self._m_res_mean_port = connected_port(self._m_res_mean)

        self._m_res_median = self._m_res_median_root + number
        self._m_res_median_port = connected_port(self._m_res_median)

        self._m_res_var = self._m_res_var_root + number
        self._m_res_var_port = connected_port(self._m_res_var)

        self._m_res_rot_mean_clip = self._m_res_rot_mean_clip_root + number
        self._m_res_rot_mean_clip_port = connected_port(self._m_res_rot_mean_clip)

        ResidualsWrapper.class_counter += 1

        self._m_basis = None
        self._m_images = None

    def __getattr__(self, item):
        """
        Forward ``im_arr``, ``cent_mask``, ``im_arr_mask`` and ``psf_im_arr`` to the wrapped
        images object; any other missing attribute yields None.

        Only the requested attribute is read. Unrelated names never touch
        ``self._m_images``, which avoids infinite recursion when that attribute does not
        exist yet (e.g. on an instance whose ``__init__`` has not run).
        """
        if item in ("im_arr", "cent_mask", "im_arr_mask", "psf_im_arr"):
            return getattr(self._m_images, item)

        return None

    @classmethod
    def create_restore(cls, filename):
        """
        Restore the image and basis wrappers from ``filename`` and combine them.

        :param filename: File previously written by :meth:`save`.
        :return: New ResidualsWrapper using the Pypeline of the restored image wrapper.
        """

        image = ImageWrapper.create_restore(filename)
        basis = BasisWrapper.create_restore(filename)

        obj = cls(image._pypeline)

        obj._m_basis = basis
        obj._m_images = image

        return obj

    def save(self, filename):
        """
        Save the wrapped images and basis to ``filename``.

        :param filename: Output file passed to both wrappers' ``save``.
        :return: None
        """
        self._m_images.save(filename)
        self._m_basis.save(filename)

    @classmethod
    def create_winstances(cls, images, basis):
        """
        Combine existing image and basis wrappers.

        :param images: Image wrapper; its Pypeline is used for the new object.
        :param basis: Basis wrapper whose central mask and image shape must match ``images``.
        :return: New ResidualsWrapper.
        :raises AssertionError: If the central masks or the image shapes differ.
        """

        # Validate before constructing the wrapper, so a mismatched pair does not consume a
        # class_counter value or create unused ports.
        # NOTE(review): assert is stripped under ``python -O``; kept because callers may rely
        # on AssertionError.
        assert np.array_equal(basis.cent_mask, images.cent_mask)
        assert np.array_equal(basis.psf_basis[0, ].shape,
                              images.im_arr[0, ].shape)

        obj = cls(images._pypeline)

        obj._m_basis = basis
        obj._m_images = images

        return obj

    def _mk_result(self, extra_rot_in=0.0):
        """
        Add and run ``CreateResidualsModule`` under the name "res_module", unless a module with
        that name is already registered in the Pypeline.

        NOTE(review): because the module name is fixed, later calls (including calls with a
        different ``extra_rot_in`` or from another ResidualsWrapper on the same Pypeline) do
        not recompute anything. Confirm this caching is intended.

        :param extra_rot_in: Additional rotation passed to the module as ``extra_rot``.
        :return: None
        """

        if "res_module" in self._pypeline._m_modules:
            return

        res_module = CreateResidualsModule(
            name_in="res_module",
            im_arr_in_tag=self._m_images._m_image_data_tag,
            psf_im_in_tag=self._m_images._m_psf_image_arr_tag,
            mask_in_tag=self._m_images._m_mask_tag,
            res_arr_out_tag=self._m_res_arr,
            res_arr_rot_out_tag=self._m_res_rot,
            res_mean_tag=self._m_res_mean,
            res_median_tag=self._m_res_median,
            res_var_tag=self._m_res_var,
            res_rot_mean_clip_tag=self._m_res_rot_mean_clip,
            extra_rot=extra_rot_in)
        self._pypeline.add_module(res_module)
        self._pypeline.run_module("res_module")

    def res_arr(self, num_coeff):
        """
        :param num_coeff: Number of basis coefficients used if the PSF model must be built.
        :return: Data stored under the residual-array tag.
        """

        # build the PSF model first if it is missing
        if self._m_images.psf_im_arr is None:
            self.mk_psfmodel(num_coeff)

        self._mk_result()

        return self._m_res_arr_port.get_all()

    def res_rot(self, num_coeff, extra_rot=0.0):
        """
        :param num_coeff: Number of basis coefficients used if the PSF model must be built.
        :param extra_rot: Additional rotation forwarded to the residual module.
        :return: Data stored under the rotated-residuals tag.
        """
        if self._m_images.psf_im_arr is None:
            self.mk_psfmodel(num_coeff)

        self._mk_result(extra_rot)

        return self._m_res_rot_port.get_all()

    def res_rot_mean(self, num_coeff, extra_rot=0.0):
        """Run :meth:`res_rot` and return the data stored under the mean-residual tag."""
        self.res_rot(num_coeff=num_coeff, extra_rot=extra_rot)

        return self._m_res_mean_port.get_all()

    def res_rot_median(self, num_coeff, extra_rot=0.0):
        """Run :meth:`res_rot` and return the data stored under the median-residual tag."""
        self.res_rot(num_coeff=num_coeff, extra_rot=extra_rot)

        return self._m_res_median_port.get_all()

    def res_rot_mean_clip(self, num_coeff, extra_rot=0.0):
        """Run :meth:`res_rot` and return the data stored under the clipped-mean tag."""
        self.res_rot(num_coeff=num_coeff, extra_rot=extra_rot)

        return self._m_res_rot_mean_clip_port.get_all()

    def res_rot_var(self, num_coeff, extra_rot=0.0):
        """Run :meth:`res_rot` and return the data stored under the variance tag."""
        self.res_rot(num_coeff=num_coeff, extra_rot=extra_rot)

        return self._m_res_var_port.get_all()

    def _psf_im(self, num_coeff):
        """Return the PSF model images, building them with ``num_coeff`` if missing."""

        if self._m_images.psf_im_arr is None:
            self.mk_psfmodel(num_coeff)

        return self._m_images.psf_im_arr

    def mk_psfmodel(self, num):
        """Build the PSF model of the wrapped images from the wrapped basis with ``num`` terms."""
        self._m_images.mk_psfmodel(self._m_basis, num)
예제 #9
0
    def __init__(self, working_pypeline):
        """
        Create the result tags and connected InputPorts of one ResidualsWrapper instance.

        Each tag is a fixed root plus the zero-padded ``ResidualsWrapper.class_counter``; the
        counter is incremented after all ports exist. Basis and images start out as None.

        :param working_pypeline: Pypeline whose ``m_data_storage`` backs the result ports.
        """
        self._pypeline = working_pypeline

        self._m_res_arr_root = "res_arr"
        self._m_res_rot_root = "res_rot"
        self._m_res_mean_root = "res_mean"
        self._m_res_median_root = "res_median"
        self._m_res_var_root = "res_var"
        self._m_res_rot_mean_clip_root = "res_rot_mean_clip"

        number = str(ResidualsWrapper.class_counter).zfill(2)

        def connected_port(tag):
            # InputPort for the tag, attached to the pipeline database
            result_port = InputPort(tag)
            result_port.set_database_connection(self._pypeline.m_data_storage)
            return result_port

        self._m_res_arr = self._m_res_arr_root + number
        self._m_res_arr_port = connected_port(self._m_res_arr)

        self._m_res_rot = self._m_res_rot_root + number
        self._m_res_rot_port = connected_port(self._m_res_rot)

        self._m_res_mean = self._m_res_mean_root + number
        self._m_res_mean_port = connected_port(self._m_res_mean)

        self._m_res_median = self._m_res_median_root + number
        self._m_res_median_port = connected_port(self._m_res_median)

        self._m_res_var = self._m_res_var_root + number
        self._m_res_var_port = connected_port(self._m_res_var)

        self._m_res_rot_mean_clip = self._m_res_rot_mean_clip_root + number
        self._m_res_rot_mean_clip_port = connected_port(self._m_res_rot_mean_clip)

        ResidualsWrapper.class_counter += 1

        self._m_basis = None
        self._m_images = None
예제 #10
0
class ImageWrapper(BasePynpointWrapper):
    """
    Image wrapper backed by numbered database tags, so several instances can share one
    Pypeline database.
    """

    def __init__(self, working_pypeline):
        """
        Create the database tags and connected InputPorts of this instance.

        Every tag is a fixed root plus the zero-padded ``ImageWrapper.class_counter``. The
        counter is incremented once all ports exist.

        :param working_pypeline: Pypeline whose ``m_data_storage`` the ports are connected to.
        """

        super(ImageWrapper, self).__init__(working_pypeline)

        def connected_port(tag):
            # InputPort for the tag, attached to the pipeline database
            port = InputPort(tag)
            port.set_database_connection(working_pypeline.m_data_storage)
            return port

        # Tag roots without identification number; used for data export.
        self._m_tag_root_image = "im_arr"
        self._m_tag_root_mask_image = "im_mask_arr"
        self._m_tag_root_mask = "im_cent_mask"
        self._m_tag_root_psf_image_arr = "psf_im_arr"

        # The old PynPoint allowed several image instances working on separate in-memory data,
        # so each ImageWrapper needs its own database entries (increasing identification number).
        number = str(ImageWrapper.class_counter).zfill(2)

        self._m_image_data_tag = self._m_tag_root_image + number
        self._m_image_data_port = connected_port(self._m_image_data_tag)

        self._m_image_data_masked_tag = self._m_tag_root_mask_image + number
        self._m_image_data_masked_port = connected_port(self._m_image_data_masked_tag)

        self._m_mask_tag = self._m_tag_root_mask + number
        self._m_mask_port = connected_port(self._m_mask_tag)

        self._m_psf_image_arr_tag = self._m_tag_root_psf_image_arr + number
        self._m_psf_image_arr_port = connected_port(self._m_psf_image_arr_tag)

        # root tag -> numbered tag (restore) and numbered tag -> root tag (save)
        self._m_restore_tag_dict = {
            self._m_tag_root_image: self._m_image_data_tag,
            self._m_tag_root_mask_image: self._m_image_data_masked_tag,
            self._m_tag_root_mask: self._m_mask_tag,
            self._m_tag_root_psf_image_arr: self._m_psf_image_arr_tag
        }

        self._m_save_tag_dict = {
            self._m_image_data_tag: self._m_tag_root_image,
            self._m_image_data_masked_tag: self._m_tag_root_mask_image,
            self._m_mask_tag: self._m_tag_root_mask,
            self._m_psf_image_arr_tag: self._m_tag_root_psf_image_arr
        }

        ImageWrapper.class_counter += 1

    def mk_psf_realisation(self, ind, full=False):
        """
        Function for making a realization of the PSF using the data stored in the object.

        :param ind: Index of the image to be modeled.
        :param full: If true, return the PSF model without applying the central mask. Any
                     false value (``False``, ``0``, ``numpy.False_``) multiplies the model by
                     ``cent_mask`` when ``cent_size`` is set.
        :return: An image of the PSF model.
        """

        im_temp = self.psf_im_arr[ind, ]

        # Truthiness check instead of ``is True`` / ``is False``: the identity comparisons
        # silently skipped masking for false values that are not the ``False`` singleton.
        if self.cent_size is not None and not full:
            im_temp = im_temp * self.cent_mask

        return im_temp
예제 #11
0
    def create_input_port(self, tag_name):
        """
        Return an opened InputPort for ``tag_name`` on ``self.storage``.

        :param tag_name: Database tag the port reads from.
        :return: The opened InputPort.
        """
        port = InputPort(tag_name, self.storage)
        port.open_port()
        return port
예제 #12
0
class BasisWrapper(BasePynpointWrapper):
    """
    PCA basis wrapper backed by numbered database tags. In addition to the image-related tags,
    it stores the basis itself and the average image.
    """

    def __init__(self, working_pypeline):
        """
        Create the database tags and connected InputPorts of this instance.

        Every tag is a fixed root plus the zero-padded ``BasisWrapper.class_counter``. The
        counter is incremented once all ports exist.

        :param working_pypeline: Pypeline whose ``m_data_storage`` the ports are connected to.
        """

        super(BasisWrapper, self).__init__(working_pypeline)

        def connected_port(tag):
            # InputPort for the tag, attached to the pipeline database
            port = InputPort(tag)
            port.set_database_connection(working_pypeline.m_data_storage)
            return port

        # Tag roots without identification number; used for data export.
        self._m_tag_root_image = "basis_arr"
        self._m_tag_root_mask_image = "basis_mask_arr"
        self._m_tag_root_mask = "basis_cent_mask"
        self._m_tag_root_basis = "psf_basis"
        self._m_tag_root_im_average = "basis_im_ave"
        self._m_tag_root_psf_image_arr = "basis_psf_im_arr"

        # Each BasisWrapper needs its own database entries (increasing identification number).
        number = str(BasisWrapper.class_counter).zfill(2)

        self._m_image_data_tag = self._m_tag_root_image + number
        self._m_image_data_port = connected_port(self._m_image_data_tag)

        self._m_image_data_masked_tag = self._m_tag_root_mask_image + number
        self._m_image_data_masked_port = connected_port(self._m_image_data_masked_tag)

        self._m_mask_tag = self._m_tag_root_mask + number
        self._m_mask_port = connected_port(self._m_mask_tag)

        self._m_psf_image_arr_tag = self._m_tag_root_psf_image_arr + number
        self._m_psf_image_arr_port = connected_port(self._m_psf_image_arr_tag)

        # Only for the basis, not for images.
        self._m_basis_tag = self._m_tag_root_basis + number
        self._m_basis_port = connected_port(self._m_basis_tag)

        self._m_im_average_tag = self._m_tag_root_im_average + number
        self._m_im_average_port = connected_port(self._m_im_average_tag)

        # root tag -> numbered tag (restore) and numbered tag -> root tag (save)
        self._m_restore_tag_dict = {
            self._m_tag_root_image: self._m_image_data_tag,
            self._m_tag_root_mask_image: self._m_image_data_masked_tag,
            self._m_tag_root_mask: self._m_mask_tag,
            self._m_tag_root_basis: self._m_basis_tag,
            self._m_tag_root_im_average: self._m_im_average_tag,
            self._m_tag_root_psf_image_arr: self._m_psf_image_arr_tag
        }

        self._m_save_tag_dict = {
            self._m_image_data_tag: self._m_tag_root_image,
            self._m_image_data_masked_tag: self._m_tag_root_mask_image,
            self._m_mask_tag: self._m_tag_root_mask,
            self._m_basis_tag: self._m_tag_root_basis,
            self._m_im_average_tag: self._m_tag_root_im_average,
            self._m_psf_image_arr_tag: self._m_tag_root_psf_image_arr
        }

        BasisWrapper.class_counter += 1

    def __getattr__(self, item):
        """
        Resolve missing attributes. The base class is tried first; a non-None result from it
        wins. Otherwise handle the basis-specific names:

        - ``im_ave``: all data of the average-image port.
        - ``psf_basis``: all data of the basis port.
        - ``psf_basis_type``: the ``basis_type`` attribute of the basis port.

        Any other name yields None.
        """
        res = super(BasisWrapper, self).__getattr__(item)

        if res is not None:
            return res

        # Ports are looked up only for the requested name. Unrelated names never read the
        # port attributes, which would otherwise recurse back into __getattr__ while they
        # are not set yet.
        if item == "im_ave":
            return self._m_im_average_port.get_all()

        if item == "psf_basis":
            return self._m_basis_port.get_all()

        if item == "psf_basis_type":
            return self._m_basis_port.get_attribute("basis_type")

        return None

    @classmethod
    def create_wdir(cls, dir_in, **kwargs):
        """Create the wrapper via the base class from ``dir_in`` and build the basis set."""

        obj = super(BasisWrapper, cls).create_wdir(dir_in, **kwargs)
        obj.mk_basis_set()
        return obj

    @classmethod
    def create_whdf5input(cls, file_in, pypline_working_place=None, **kwargs):
        """Create the wrapper via the base class from an HDF5 file and build the basis set."""

        obj = super(BasisWrapper,
                    cls).create_whdf5input(file_in, pypline_working_place,
                                           **kwargs)
        obj.mk_basis_set()

        return obj

    def mk_basis_set(self):
        """
        Add and run a ``MakePCABasisModule`` named "basis_creation".

        The module reads and writes the image tag of this wrapper (input and output tag are
        the same) and writes the average image and the basis to their own tags.
        """

        basis_creation = MakePCABasisModule(
            name_in="basis_creation",
            im_arr_in_tag=self._m_image_data_tag,
            im_arr_out_tag=self._m_image_data_tag,
            im_average_out_tag=self._m_im_average_tag,
            basis_out_tag=self._m_basis_tag)

        self._pypeline.add_module(basis_creation)
        self._pypeline.run_module("basis_creation")

    def mk_orig(self, ind):
        """Intentionally a no-op for the basis wrapper."""
        pass

    def mk_psfmodel(self, num):
        """
        Build the PSF model of this basis' own images, using the basis itself.

        :param num: Number of basis components, forwarded to the base implementation.
        """

        # ``self`` is passed explicitly as the basis argument of the base method.
        super(BasisWrapper, self).mk_psfmodel(self, num)