    def clean_up(self):
        """
        Closes all open FITS files so they don't remain in memory.
        """
        print("Closing FITS files...")
        for hdul in self.hdul_list:
            hdul.close()
        logging_tools.log("FITS files closed successfully.")
        print("Files closed.")
    def convert_catalogue_to_metadata(self):
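        """
        Converts a provided catalogue into the metadata DataFrame used
        internally, filling in any missing columns ('original_image',
        'objid', 'peak_flux') with sensible defaults where possible.
        """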

        if 'original_image' not in self.catalogue.columns:
            if len(self.images) > 1:
                logging_tools.log("If multiple FITS images are used the "
                                  "original_image column must be provided "
                                  "in the catalogue to identify which image "
                                  "the source belongs to.",
                                  level='ERROR')

                raise ValueError("Incorrect input supplied")

            else:
                self.catalogue['original_image'] = \
                    [list(self.images.keys())[0]] * len(self.catalogue)

        if 'objid' not in self.catalogue.columns:
            self.catalogue['objid'] = np.arange(len(self.catalogue))

        if 'peak_flux' not in self.catalogue.columns:
            self.catalogue['peak_flux'] = [np.nan] * len(self.catalogue)

        cols = ['original_image', 'x', 'y']

        for c in cols[1:]:
            if c not in self.catalogue.columns:
                logging_tools.log("""If a catalogue is provided the x and y
                columns (corresponding to pixel values) must be present""",
                                  level='ERROR')

                raise ValueError("Incorrect input supplied")

        if 'ra' in self.catalogue.columns:
            cols.append('ra')
        if 'dec' in self.catalogue.columns:
            cols.append('dec')
        if 'peak_flux' in self.catalogue.columns:
            cols.append('peak_flux')

        met = {}
        for c in cols:
            met[c] = self.catalogue[c].values

        the_index = np.array(self.catalogue['objid'].values, dtype='str')
        self.metadata = pd.DataFrame(met, index=the_index)
        self.metadata['x'] = self.metadata['x'].astype('int')
        self.metadata['y'] = self.metadata['y'].astype('int')
    def _execute_function(self, df):
        """
        Does the work in actually running the scaler.

        Parameters
        ----------
        df : pd.DataFrame or similar
            The input anomaly scores to rescale.

        Returns
        -------
        pd.DataFrame
            Contains the same original index and columns as the input
            features, with the anomaly scores scaled according to the
            arguments given to __init__.
        """
        print('Running anomaly score rescaler...')

        if self.column_name == 'all':
            cols = df.columns
        else:
            cols = [self.column_name]
        try:
            scores = df[cols]
        except KeyError:
            msg = ('Requested column ' + self.column_name + ' not available '
                   'in input dataframe. No rescaling has been performed.')

            logging_tools.log(msg, level='WARNING')
            return df

        if self.lower_is_weirder:
            scores = -scores

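        # Linearly rescale the scores from their original range onto
        # [new_min, new_max] (min-max normalisation)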
        scores = (self.new_max - self.new_min) * (scores - scores.min()) / \
            (scores.max() - scores.min()) + self.new_min

        if self.convert_integer:
            scores = scores.round()

        return scores
    def run(self, data):
        """
        This is the external-facing function that should always be called
        (rather than _execute_function). This function will automatically check
        if this stage has already been run with the same arguments and on the
        same data. This can allow a much faster user experience avoiding
        rerunning functions unnecessarily.

        Parameters
        ----------
        data : pd.DataFrame
            Input data on which to run this pipeline stage.

        Returns
        -------
        pd.DataFrame
            Output
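
        Examples
        --------
        A minimal usage sketch (assuming ``stage`` is an instance of a
        pipeline stage and ``scores_df`` is a DataFrame of anomaly scores)::

            rescaled = stage.run(scores_df)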
        """
        new_checksum = self.hash_data(data)
        if self.args_same and new_checksum == self.checksum:
            # This means we've already run this function for all instances in
            # the input and with the same arguments
            msg = "Pipeline stage %s previously called " \
                  "with same arguments and same data. Loading from file. " \
                  "Use 'force_rerun=True' in init args to override this " \
                  "behavior." % self.class_name
            logging_tools.log(msg, level='WARNING')
            return self.previous_output
        else:
            msg_string = self.function_call_signature + ' - checksum: ' + \
                str(new_checksum)
            logging_tools.log(msg_string)
            print('Running', self.class_name, '...')
            t1 = time.time()
            if self.drop_nans:
                output = self._execute_function(data.dropna())
            else:
                output = self._execute_function(data)
            self.save(output, self.output_file)
            print('Done! Time taken:', (time.time() - t1), 's')
            return output
def fit_ellipse(contour, image, return_params=False, filled=True):
    """
    Fits an ellipse to a contour and returns a binary image representation of
    the ellipse.

    Parameters
    ----------
    contour : np.ndarray
        Array of x,y values describing the contour (as returned by OpenCV's
        findContours function)
    image : np.ndarray
        The original image the contour was fit to.
    return_params : bool
        If True, also returns the parameters of the fitted ellipse
    filled : bool
        If True, the ellipse is drawn filled on a blank array; otherwise only
        its outline is drawn onto a copy of the input image
    Returns
    -------
    np.ndarray
        2d binary image with representation of the ellipse
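
    Examples
    --------
    A minimal usage sketch (assuming ``cnt`` is a single contour returned by
    cv2.findContours and ``img`` is the corresponding 2d image)::

        ellipse_img, params = fit_ellipse(cnt, img, return_params=True)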
    """

    if filled:
        thickness = -1
        y_npix = image.shape[0]
        x_npix = image.shape[1]
        ellipse_arr = np.zeros([y_npix, x_npix], dtype=np.float64)
    else:
        thickness = 1
        ellipse_arr = image.copy()

    # Sets some defaults for when the fitting fails
    default_return_params = [np.nan] * 5
    raised_error = False

    try:
        ((x0, y0), (maj_axis, min_axis), theta) = cv2.fitEllipse(contour)
        ellipse_params = x0, y0, maj_axis, min_axis, theta

        if np.any(np.isnan(ellipse_params)):
            raised_error = True
            logging_tools.log('fit_ellipse failed with unknown error')

    except cv2.error as e:
        logging_tools.log('fit_ellipse failed with cv2 error: ' + e.msg)
        raised_error = True

    if raised_error:
        if return_params:
            return ellipse_arr, default_return_params
        else:
            return ellipse_arr

    x0 = int(np.round(x0))
    y0 = int(np.round(y0))
    maj_axis = int(np.round(maj_axis))
    min_axis = int(np.round(min_axis))
    theta = int(np.round(theta))

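    # cv2.ellipse expects semi-axes, so the fitted full axis lengths are
    # halved here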
    cv2.ellipse(ellipse_arr, (x0, y0), (maj_axis // 2, min_axis // 2), theta,
                0, 360, (1, 1, 1), thickness)

    if return_params:
        return ellipse_arr, ellipse_params
    else:
        return ellipse_arr
    def _execute_function(self, image):
        """
        Does the work in actually extracting the ellipse fitted features

        Parameters
        ----------
        image : np.ndarray
            Input image

        Returns
        -------
        array
            Contains the extracted ellipse fitted features
        """

        # Contours can only be computed on a 2d array, so select a single
        # channel if the image is multi-channel
        if len(image.shape) > 2:
            if self.channel is None:
                raise ValueError('Contours cannot be determined for '
                                 'multi-channel images, please set the '
                                 'channel kwarg.')
            else:
                this_image = image[:, :, self.channel]
        else:
            this_image = image

        x0 = y0 = -1
        x_cent = this_image.shape[0] // 2
        y_cent = this_image.shape[1] // 2

        feats = []
        # Start with the closest in contour (highest percentile)
        percentiles = np.sort(self.percentiles)[::-1]

        if np.all(this_image == 0):
            failed = True
            failure_message = "Invalid cutout for feature extraction"
        else:
            failed = False
            failure_message = ""

        for p in percentiles:
            if failed:
                contours = []
            else:
                thresh = np.percentile(this_image[this_image > 0], p)
                contours, hierarchy = find_contours(this_image, thresh)

                x_contours = np.zeros(len(contours))
                y_contours = np.zeros(len(contours))

            # First attempt to find the central point of the inner most contour
            if len(contours) != 0:
                for k in range(len(contours)):
                    M = cv2.moments(contours[k])
                    try:
                        x_contours[k] = int(M["m10"] / M["m00"])
                        y_contours[k] = int(M["m01"] / M["m00"])
                    except ZeroDivisionError:
                        pass
                if x0 == -1:
                    x_diff = x_contours - x_cent
                    y_diff = y_contours - y_cent
                else:
                    x_diff = x_contours - x0
                    y_diff = y_contours - y0

                # Will try to find the CLOSEST contour to the central one
                r_diff = np.sqrt(x_diff**2 + y_diff**2)

                ind = np.argmin(r_diff)

                if x0 == -1:
                    x0 = x_contours[ind]
                    y0 = y_contours[ind]

                c = contours[ind]

                params = get_ellipse_leastsq(c, this_image)
                # Params return in this order:
                # residual, x0, y0, maj_axis, min_axis, theta
                if np.any(np.isnan(params)):
                    failed = True
                else:
                    if params[3] == 0 or params[4] == 0:
                        aspect = 1
                    else:
                        aspect = params[4] / params[3]

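                    # Force the aspect ratio to be >= 1; an implausibly
                    # large value indicates a degenerate fit, so fall back
                    # to 1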
                    if aspect < 1:
                        aspect = 1 / aspect
                    if aspect > 100:
                        aspect = 1

                    new_params = params[:3] + [aspect] + [params[-1]]
                    feats.append(new_params)
            else:
                failed = True
                failure_message = "No contour found"

            if failed:
                feats.append([np.nan] * 5)
                logging_tools.log(failure_message)

        # Now we have the leastsq value, x0, y0, aspect_ratio, theta for each
        # sigma
        # Normalise things relative to the highest threshold value
        # If there were problems with any sigma levels, set all values to NaNs
        if np.any(np.isnan(feats)):
            return [np.nan] * 4 * len(self.percentiles)
        else:
            max_ind = np.argmax(self.percentiles)

            residuals = []
            dist_to_centre = []
            aspect = []
            theta = []

            x0_max_sigma = feats[max_ind][1]
            y0_max_sigma = feats[max_ind][2]
            aspect_max_sigma = feats[max_ind][3]
            theta_max_sigma = feats[max_ind][4]

            for n in range(len(feats)):
                prms = feats[n]
                residuals.append(prms[0])
                if prms[1] == 0 or prms[2] == 0:
                    r = 0
                else:
                    x_diff = prms[1] - x0_max_sigma
                    y_diff = prms[2] - y0_max_sigma
                    r = np.sqrt((x_diff)**2 + (y_diff)**2)
                dist_to_centre.append(r)
                aspect.append(prms[3] / aspect_max_sigma)
                # Because there's redundancy about which way an ellipse
                # is aligned (180 degree symmetry), we always take the
                # acute angle between the two orientations
                theta_diff = np.abs(prms[4] - theta_max_sigma) % 180
                if theta_diff > 90:
                    theta_diff = 180 - theta_diff
                theta.append(theta_diff)

            return np.hstack((residuals, dist_to_centre, aspect, theta))
    def run_on_dataset(self, dataset=None):
        """
        This function should be called for pipeline stages that perform
        feature extraction, since they require a Dataset object as input.
        This is an external-facing function that should always be called
        (rather than _execute_function). This function will automatically check
        if this stage has already been run with the same arguments and on the
        same data. This can allow a much faster user experience avoiding
        rerunning functions unnecessarily.

        Parameters
        ----------
        dataset : Dataset
            The Dataset object on which to run this feature extraction 
            function, by default None

        Returns
        -------
        pd.DataFrame
            Output
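
        Examples
        --------
        A minimal usage sketch (assuming ``stage`` is a feature extraction
        stage and ``dataset`` is a Dataset instance)::

            features = stage.run_on_dataset(dataset)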
        """
        # *** WARNING: this has not been tested against adding new data and
        # *** ensuring the function is called for new data only
        dat = dataset.get_sample(dataset.index[0])
        # Have to do a slight hack if the data is too high dimensional
        if len(dat.shape) > 2:
            dat = dat.ravel()
        new_checksum = self.hash_data(dat)
        if not self.args_same or new_checksum != self.checksum:
            # If the arguments have changed we rerun everything
            msg_string = self.function_call_signature + ' - checksum: ' + \
                str(new_checksum)
            logging_tools.log(msg_string)
        else:
            # Otherwise we only run instances not already in the output
            msg = "Pipeline stage %s previously called " \
                "with same arguments. Loading from file. Will only run " \
                "for new samples. Use 'force_rerun=True' in init args " \
                "to override this behavior." % self.class_name
            logging_tools.log(msg, level='WARNING')

        print('Extracting features using', self.class_name, '...')
        t1 = time.time()
        logged_nan_msg = False
        nan_msg = "NaNs detected in some input images." \
                  "NaNs will be set to zero. You can change " \
                  "behaviour by setting drop_nan=False"

        new_index = []
        output = []
        n = 0
        for i in dataset.index:
            if i not in self.previous_output.index or not self.args_same:
                if n % 100 == 0:
                    print(n, 'instances completed')
                input_instance = dataset.get_sample(i)

                if self.drop_nans and np.any(np.isnan(input_instance)):
                    input_instance = np.nan_to_num(input_instance)
                    if not logged_nan_msg:
                        print(nan_msg)
                        logging_tools.log(nan_msg, level='WARNING')
                        logged_nan_msg = True
                out = self._execute_function(input_instance)
                if np.any(np.isnan(out)):
                    logging_tools.log("Feature extraction failed for id " + i)
                output.append(out)
                new_index.append(i)
            n += 1

        new_output = pd.DataFrame(data=output,
                                  index=new_index,
                                  columns=self.labels)

        index_same = new_output.index.equals(self.previous_output.index)
        if self.args_same and not index_same:
            output = pd.concat((self.previous_output, new_output))
        else:
            output = new_output

        if self.save_output:
            self.save(output, self.output_file)
        print('Done! Time taken:', (time.time() - t1), 's')

        return output
    def __init__(self, *args, **kwargs):
        """
        Base Dataset object that all other dataset objects should inherit from.
        Whenever a child of this class is implemented, super().__init__()
        should be called and explicitly passed all kwargs of the child class,
        to ensure correct logging and saving of files.

        Parameters
        ----------
        filename : str
            If a single file (of any type) is to be read from, the path can be
            given using this kwarg.
        directory : str
            A directory can be given instead of an explicit list of files. The
            child class will load all appropriate files in this directory.
        list_of_files : list
            Instead of the above, a list of files to be loaded can be
            explicitly given.
        output_dir : str
            The directory to save the log file and all outputs to. Defaults to
            './' 
        """
        self.data_type = None

        filename = kwargs.get('filename', '')
        directory = kwargs.get('directory', '')
        list_of_files = kwargs.get('list_of_files', [])
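
        # Resolve the file list: an explicit filename takes precedence,
        # then a list of files (joined to directory if one is given),
        # then the contents of directory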
        if len(filename) != 0:
            self.files = [filename]
        elif len(list_of_files) != 0 and len(directory) == 0:
            # Assume the list of files are absolute paths
            self.files = list_of_files
        elif len(list_of_files) != 0 and len(directory) != 0:
            # Assume the list of files are relative paths to directory
            fls = list_of_files
            self.files = [os.path.join(directory, f) for f in fls]
        elif len(directory) != 0:
            # Assume directory contains all the files we need
            fls = os.listdir(directory)
            fls.sort()
            self.files = [os.path.join(directory, f) for f in fls]
        else:
            self.files = []

        # Handles automatic file reading and writing
        if 'output_dir' in kwargs:
            self.output_dir = kwargs['output_dir']
        else:
            self.output_dir = './'

        # This allows the automatic logging every time this class is
        # instantiated (i.e. every time this pipeline stage
        # is run). That means any class that inherits from this base class
        # will have automated logging.

        logging_tools.setup_logger(log_directory=self.output_dir,
                                   log_filename='astronomaly.log')

        class_name = type(self).__name__
        function_call_signature = logging_tools.format_function_call(
            class_name, *args, **kwargs)
        logging_tools.log(function_call_signature)
    def __init__(self,
                 fits_index=None,
                 window_size=128,
                 window_shift=None,
                 display_image_size=128,
                 band_prefixes=[],
                 bands_rgb={},
                 transform_function=None,
                 display_transform_function=None,
                 plot_square=False,
                 catalogue=None,
                 plot_cmap='hot',
                 **kwargs):
        """
        Read in a set of images either from a directory or from a list of file
        paths (absolute). Inherits from Dataset class.

        Parameters
        ----------
        filename : str
            If a single file (of any type) is to be read from, the path can be
            given using this kwarg.
        directory : str
            A directory can be given instead of an explicit list of files. The
            child class will load all appropriate files in this directory.
        list_of_files : list
            Instead of the above, a list of files to be loaded can be
            explicitly given.
        output_dir : str
            The directory to save the log file and all outputs to. Defaults to
            './' 
        fits_index : integer, optional
            If these are fits files, specifies which HDU object in the list to
            work with
        window_size : int, tuple or list, optional
            The size of the cutout in pixels. If an integer is provided, the 
            cutouts will be square. Otherwise a list of 
            [window_size_x, window_size_y] is expected.
        window_shift : int, tuple or list, optional
            The size of the window shift in pixels. If the shift is less than 
            the window size, a sliding window is used to create cutouts. This 
            can be particularly useful for (for example) creating a training 
            set for an autoencoder. If an integer is provided, the shift will 
            be the same in both directions. Otherwise a list of
            [window_shift_x, window_shift_y] is expected.
        display_image_size : int, optional
            The size of the image to be displayed on the
            web page. If the image is smaller than this, it will be
            interpolated up to the higher number of pixels. If larger, it will
            be downsampled.
        band_prefixes : list
            Allows you to specify a prefix for an image which corresponds to a
            band identifier. This has to be a prefix and the rest of the image
            name must be identical in order for Astronomaly to detect these
            images should be stacked together. 
        bands_rgb : dict
            Maps the input bands (in separate folders) to rgb values to allow
            false colour image plotting. Note that here you can only select
            three bands to plot although you can use as many bands as you like
            in band_prefixes. The dictionary should have 'r', 'g' and 'b' as
            keys with the band prefixes as values.
        transform_function : function or list, optional
            The transformation function or list of functions that will be
            applied to each cutout. The function should take an input 2d array
            (the cutout) and return an output 2d array. If a list is provided,
            each function is applied in the order of the list.
        display_transform_function : function or list, optional
            As transform_function, but applied only to the image displayed in
            the webapp. Defaults to transform_function if not given.
        catalogue : pandas.DataFrame or similar
            A catalogue of the positions of sources around which cutouts will
            be extracted. Note that a cutout of size "window_size" will be
            extracted around these positions and must be the same for all
            sources. 
        plot_square : bool, optional
            If True this will add a white border indicating the boundaries of
            the original cutout when the image is displayed in the webapp.
        plot_cmap : str, optional
            The colormap with which to plot the image
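
        Examples
        --------
        A minimal usage sketch (assuming this class is exposed as
        ImageDataset and ``image_dir`` is a directory of FITS files)::

            dataset = ImageDataset(directory=image_dir, window_size=128)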
        """

        super().__init__(fits_index=fits_index,
                         window_size=window_size,
                         window_shift=window_shift,
                         display_image_size=display_image_size,
                         band_prefixes=band_prefixes,
                         bands_rgb=bands_rgb,
                         transform_function=transform_function,
                         display_transform_function=display_transform_function,
                         plot_square=plot_square,
                         catalogue=catalogue,
                         plot_cmap=plot_cmap,
                         **kwargs)
        self.known_file_types = [
            'fits', 'fits.fz', 'fits.gz', 'FITS', 'FITS.fz', 'FITS.gz'
        ]
        self.data_type = 'image'

        images = {}
        tracemalloc.start()

        if len(band_prefixes) != 0:
            # Get the matching images in different bands
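            # (maps the base file name, with its band prefix stripped, to
            # the list of matching per-band files)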
            bands_files = {}

            for p in band_prefixes:
                for f in self.files:
                    if p in f:
                        start_ind = f.find(p)
                        end_ind = start_ind + len(p)
                        flname = f[end_ind:]
                        if flname not in bands_files.keys():
                            bands_files[flname] = [f]
                        else:
                            bands_files[flname] += [f]

            for k in bands_files.keys():
                extension = k.split('.')[-1]
                if extension in ('fz', 'gz'):
                    extension = '.'.join(k.split('.')[-2:])
                if extension in self.known_file_types:
                    try:
                        astro_img = AstroImage(bands_files[k],
                                               file_type=extension,
                                               fits_index=fits_index,
                                               name=k)
                        images[k] = astro_img

                    except Exception as e:
                        msg = "Cannot read image " + k + "\n \
                            Exception is: " + (str)(e)
                        logging_tools.log(msg, level="ERROR")

            # Also convert the rgb dictionary into an index dictionary
            # corresponding to the order of the band prefixes
            if len(bands_rgb) == 0:
                self.bands_rgb = {'r': 0, 'g': 1, 'b': 2}
            else:
                self.bands_rgb = {}
                for k in bands_rgb.keys():
                    band = bands_rgb[k]
                    ind = band_prefixes.index(band)
                    self.bands_rgb[k] = ind
        else:
            for f in self.files:
                extension = f.split('.')[-1]
                if extension in ('fz', 'gz'):
                    extension = '.'.join(f.split('.')[-2:])
                if extension in self.known_file_types:
                    try:
                        astro_img = AstroImage([f],
                                               file_type=extension,
                                               fits_index=fits_index)
                        images[astro_img.name] = astro_img
                    except Exception as e:
                        msg = "Cannot read image " + f + "\n \
                            Exception is: " + (str)(e)
                        logging_tools.log(msg, level="ERROR")

        if len(images) == 0:
            msg = "No images found; Astronomaly cannot proceed."
            logging_tools.log(msg, level="ERROR")
            raise IOError(msg)

        try:
            self.window_size_x = window_size[0]
            self.window_size_y = window_size[1]
        except TypeError:
            self.window_size_x = window_size
            self.window_size_y = window_size

        # Allows sliding windows
        if window_shift is not None:
            try:
                self.window_shift_x = window_shift[0]
                self.window_shift_y = window_shift[1]
            except TypeError:
                self.window_shift_x = window_shift
                self.window_shift_y = window_shift
        else:
            self.window_shift_x = self.window_size_x
            self.window_shift_y = self.window_size_y

        self.images = images
        self.transform_function = transform_function
        if display_transform_function is None:
            self.display_transform_function = transform_function
        else:
            self.display_transform_function = display_transform_function

        self.plot_square = plot_square
        self.plot_cmap = plot_cmap
        self.catalogue = catalogue
        self.display_image_size = display_image_size
        self.band_prefixes = band_prefixes

        self.metadata = pd.DataFrame(data=[])
        if self.catalogue is None:
            self.create_catalogue()
        else:
            self.convert_catalogue_to_metadata()
            print('A catalogue of', len(self.metadata),
                  'sources has been provided.')

        if 'original_image' in self.metadata.columns:
            for img in np.unique(self.metadata.original_image):
                if img not in images.keys():
                    logging_tools.log('Image ' + img + ' found in catalogue '
                                      'but not in provided image data. '
                                      'Removing from catalogue.',
                                      level='WARNING')
                    msk = self.metadata.original_image == img
                    self.metadata.drop(self.metadata.index[msk], inplace=True)
                    print('Catalogue reduced to', len(self.metadata),
                          'sources')

        self.index = self.metadata.index.values