예제 #1
0
    def get_dark_frame(self):
        '''
        Build a dark-rate frame (counts per pixel per second) from the dark observations listed in
        the flatcal configuration.

        Start and integration times come from ``self.cfg.flatcal.dark_data['start_times']`` and
        ``['int_times']``. Each may be a single value or a sequence. Each start time names a
        ``<start>.h5`` file in ``self.h5_directory``. The counts of all dark observations are
        summed and divided by the summed integration time, effectively stitching the short darks
        into one long dark. This is useful for legacy data with no dedicated dark observation but
        with stretches where the filter wheel was closed.

        NOTE: earlier notes state that no dark is subtracted when ``self.use_wavecal`` is True
        (only total counts, not energy information, are used). That flag is NOT checked here;
        presumably the caller handles it.

        Side effects: sets ``self.dark_start``, ``self.dark_int`` and ``self.dark_h5_file_names``
        when darks are requested.

        :return: 2D array of dark counts per pixel per second. When ``self.dark_h5_file_names`` is
                 empty, a zero array shaped like ``self.spectralCubes[0][:, :, 0]``.
        :raises ValueError: if no dark start times are configured, or the numbers of start times
                            and integration times differ.
        '''
        if not self.dark_h5_file_names:
            return np.zeros_like(self.spectralCubes[0][:, :, 0])

        def _as_list(value):
            # Accept one value or a sequence. Unconditionally wrapping a sequence in another
            # list would produce file names such as "[1, 2].h5".
            if isinstance(value, (list, tuple, np.ndarray)):
                return list(value)
            return [value]

        dark_data = self.cfg.flatcal.dark_data
        self.dark_start = _as_list(dark_data['start_times'])
        self.dark_int = _as_list(dark_data['int_times'])
        if not self.dark_start:
            raise ValueError('No dark start_times configured for the flat calibration')
        if len(self.dark_start) != len(self.dark_int):
            raise ValueError(
                'Got {} dark start times but {} dark integration times'.format(
                    len(self.dark_start), len(self.dark_int)))

        self.dark_h5_file_names = [
            os.path.join(self.h5_directory, str(t) + '.h5') for t in self.dark_start
        ]
        getLogger(__name__).info('Loading dark frames for Laser flat')

        # Accumulate using each frame's own shape rather than a hard-coded array size.
        total_counts = None
        for file_name, int_time in zip(self.dark_h5_file_names, self.dark_int):
            obs = Photontable(file_name)
            frame = np.asarray(
                obs.getPixelCountImage(integrationTime=int_time)['image'], dtype=float)
            total_counts = frame if total_counts is None else total_counts + frame

        total_int_time = float(np.sum(self.dark_int))
        return total_counts / total_int_time
예제 #2
0
def fetchimg(ob, kwargs):
    """
    Load an image HDU for observation ``ob`` from its photon table.

    :param ob: object whose ``h5`` attribute is the path of the photon-table file.
    :param kwargs: mapping of keyword arguments forwarded to the Photontable image getter.
                   If ``kwargs['nwvl'] > 1``, a spectral cube is requested via
                   ``getSpectralCube``. Otherwise a count image is requested via
                   ``getPixelCountImage``, with any ``wvlN`` entry removed first.
                   The caller's mapping is not modified.
    :return: the HDU returned by the getter, with ``PIXELA`` set to ``10.0**2`` in its header.
    """
    # Work on a copy so that dropping 'wvlN' below does not leak back into the caller's dict.
    kwargs = dict(kwargs)
    of = Photontable(ob.h5)
    if kwargs['nwvl'] > 1:
        im = of.getSpectralCube(hdu=True, **kwargs)
    else:
        kwargs.pop('wvlN', None)
        im = of.getPixelCountImage(hdu=True, **kwargs)
    del of

    # TODO move to photontable and fetch it from an import
    # HDUs expose their header as `.header` (the previous `.headerh` was a typo).
    im.header['PIXELA'] = 10.0**2
    return im
예제 #3
0
    def determine_mean_photperRID(self):
        """
        Compute, per dataset in ``self.datasets``, total counts divided by the number of pixels
        with a positive count.

        Count images are cached in ``self.count_images``, keyed by dataset. For a dataset not yet
        cached, the code works as follows:

        - It picks the result with the smallest ``queryt`` from
          ``self.query_iter(SHORTT_QUERY, dataset=d)``.
        - It takes a 5 s count image starting at 30 s from that result's file.
        - It scales that image by ``info['expTime'] / 5`` to approximate the full exposure.
        - It stores the scaled image in the cache.

        :return: dict mapping each dataset to ``cnt_img.sum() / (cnt_img > 0).sum()``.
        """
        first_sec = 30  # start of the sampled window (s)
        int_time = 5  # length of the sampled window (s)

        ret = {}
        for d in self.datasets:
            if d in self.count_images:
                cnt_img = self.count_images[d]
            else:
                res = list(self.query_iter((SHORTT_QUERY), dataset=d))
                fastest = res[int(np.argmin([r.queryt for r in res]))]
                print('This might take a while: {:.0f} s'.format(fastest.queryt))
                of = Photontable(fastest.file)
                cnt_img = of.getPixelCountImage(firstSec=first_sec,
                                                integrationTime=int_time)['image']
                # Scale out of place. An in-place `*=` raises for integer images (float->int
                # casting) and would modify the array handed back by Photontable.
                cnt_img = cnt_img * (of.info['expTime'] / int_time)
                del of
                self.count_images[d] = cnt_img
            ret[d] = cnt_img.sum() / (cnt_img > 0).sum()
        return ret
        planet_photons = np.array(planet_photons).T.astype(np.float32)
        planet_photons = np.insert(planet_photons,
                                   obj=1,
                                   values=cam.phase_cal(ap.wvl_range[0]),
                                   axis=0)
        planet_photons[0] = (planet_photons[0] + abs_step) * sp.sample_time
        photons = np.concatenate((star_photons, planet_photons), axis=1)

        # cam.save_photontable(photonlist=photons, index=None, populate_subsidiaries=False)
        stem = cam.arange_into_stem(photons.T,
                                    (cam.array_size[1], cam.array_size[0]))
        stem = list(map(list, zip(*stem)))
        stem = cam.remove_close(stem)
        photons = cam.ungroup(stem)
        photons = photons[[0, 1, 3, 2]]
        # grid(cam.rebin_list(photons), show=False)
        cam.populate_photontable(photons=photons, finalise=False)
    cam.populate_photontable(photons=[], finalise=True)
    grid(cam.rebin_list(photons), show=False)

    # to look at the photontable
    obs = Photontable(iop.photonlist)
    print(obs.photonTable)

    # to plot an image of all the photons on the array
    image = obs.getPixelCountImage(integrationTime=None)['image']
    quick2D(image, show=False, title='Const planet photons')

    plt.show(block=True)
예제 #5
0
def mask_hot_pixels(file,
                    method='hpm_flux_threshold',
                    step=30,
                    startt=0,
                    stopt=None,
                    ncpu=1,
                    **methodkw):
    """
    Main entry point of the bad pixel masking code.

    Procedure:

    1. Opens the photon-table file ``file``.
    2. Splits ``[startt, stopt)`` into steps of ``step`` seconds, clipping the last step at
       ``stopt``.
    3. Runs the detection function named ``method`` on the count image of each step.
    4. Writes the result with ``applyBadPixelMask``, passing three masks:

       - the per-pixel AND of the hot masks over all steps;
       - the per-pixel AND of the cold masks over all steps;
       - an "unstable" mask of pixels flagged hot in some steps but not in all of them.

    Files whose ``info['isBadPixMasked']`` is already set are left untouched.

    :param file: path of the photon-table (h5) file to mask.
    :param method: name of a module-level detection function, e.g. hpm_median_movingbox,
                   hpm_flux_threshold, hpm_laplacian, hpm_cps_cut. It is called as
                   ``func(image, **methodkw)`` and must return a dict with 'hot_mask' and
                   'cold_mask' entries.
    :param step: length of each time step in seconds (default 30). If longer than the requested
                 range, a single step covering the whole range is used.
    :param startt: time (s) at which to begin masking (default 0).
    :param stopt: time (s) at which to stop masking. None (default) means the file's 'expTime'
                  header value.
    :param ncpu: accepted for interface compatibility; not used by this function.
    :param methodkw: extra keyword arguments forwarded to the detection function.
    :raises ValueError: if ``startt`` is not less than ``stopt``.
    :return: None. Applies the relevant pixel flags to the file (see pixelflags.py).
    """
    obs = Photontable(file)
    if obs.info['isBadPixMasked']:
        getLogger(__name__).info('{} is already bad pixel calibrated'.format(
            obs.fileName))
        return

    if stopt is None:
        stopt = obs.getFromHeader('expTime')
    if not startt < stopt:
        raise ValueError('startt ({}) must be less than stopt ({})'.format(startt, stopt))
    if step > stopt - startt:
        getLogger(__name__).warning((
            'Hot pixel step time longer than exposure time by {:.0f} s, using full '
            'exposure').format(abs(stopt - startt - step)))
        step = stopt - startt

    # Start time of each step (s).
    step_starts = np.arange(startt, stopt, step, dtype=int)
    # End time of each step, clipped to the requested range. The integration time is derived
    # from this so the last step does not read data beyond `stopt`.
    step_ends = np.minimum(step_starts + step, stopt)

    func = globals()[method]

    # Stack of masks, one layer per time step.
    hot_masks = np.zeros([obs.nXPix, obs.nYPix, step_starts.size], dtype=bool)
    cold_masks = np.zeros([obs.nXPix, obs.nYPix, step_starts.size], dtype=bool)

    for i, (slice_start, slice_end) in enumerate(zip(step_starts, step_ends)):
        getLogger(__name__).info('Processing time slice: {} - {} s'.format(
            slice_start, slice_end))
        raw_image_dict = obs.getPixelCountImage(
            firstSec=slice_start,
            integrationTime=slice_end - slice_start,
            applyWeight=True,
            applyTPFWeight=True,
            scaleByEffInt=method == 'hpm_cps_cut')
        bad_pixel_solution = func(raw_image_dict['image'], **methodkw)
        hot_masks[:, :, i] = bad_pixel_solution['hot_mask']
        cold_masks[:, :, i] = bad_pixel_solution['cold_mask']

    # Unstable: flagged hot in at least one step but not in every step. (The previous
    # per-pixel test `not all(v) or all(v)` was always true, so nothing was ever marked.)
    unstable_mask = hot_masks.any(axis=-1) & ~hot_masks.all(axis=-1)

    # Combine the bad pixel masks into a master mask. Always restore read-only mode.
    obs.enablewrite()
    try:
        obs.applyBadPixelMask(np.all(hot_masks, axis=-1),
                              np.all(cold_masks, axis=-1), unstable_mask)
    finally:
        obs.disablewrite()
예제 #6
0
def get_transforms(ditherfile,
                   datadir,
                   wvl_start=None,
                   wvl_stop=None,
                   fwhm_guess=3.0,
                   fit_power=1,
                   CONEX_ERROR=0.0001,
                   plot=False):
    """
    Fit polynomial transforms between CONEX dither positions and detector pixel positions.

    For each observation in the dither, the file ``<datadir>/<start>.h5`` is processed as
    follows:

    1. A count image is built with ``applyWeight=False``, excluding pixels with
       ``pixelflags.PROBLEM_FLAGS`` and limited to ``wvl_start``/``wvl_stop``. The image is
       then transposed.
    2. Zero-count pixels are masked.
    3. DAOStarFinder searches the median-subtracted image with a threshold of 5x the
       sigma-clipped std.
    4. The highest-flux source is refined with a 2D-Gaussian centroid in a box of
       ``10 * fwhm_guess``.

    The resulting pixel positions and the dither's CONEX positions are then used to estimate
    polynomial transforms of order ``fit_power`` in both directions.

    :param ditherfile: path of the dither description file.
    :param datadir: directory containing the observation h5 files.
    :param wvl_start: lower wavelength limit passed to getPixelCountImage (None: no limit).
    :param wvl_stop: upper wavelength limit passed to getPixelCountImage (None: no limit).
    :param fwhm_guess: FWHM guess in pixels for source finding; the centroid box is 10x this.
    :param fit_power: polynomial order of the fitted transforms.
    :param CONEX_ERROR: accepted for interface compatibility; not used by this function.
    :param plot: if True, show every image with its estimated (red) and centroided (blue)
                 source position, then plot the fit residuals.
    :raises RuntimeError: if no source is detected in one of the images.
    :return: ``(xform_con2pix, xform_pix2con)``, the CONEX->pixel and pixel->CONEX transforms.
    """
    dither = MKIDDitheredObservation(os.path.basename(ditherfile), ditherfile,
                                     None, None)
    obs_files = [
        os.path.join(datadir, '{}.h5'.format(o.start)) for o in dither.obs
    ]

    box_size = fwhm_guess * 10

    debug_images = []
    pixel_positions = []
    source_est = []

    for file in obs_files:
        obs = Photontable(file)
        data = obs.getPixelCountImage(applyWeight=False,
                                      exclude_flags=pixelflags.PROBLEM_FLAGS,
                                      wvlStart=wvl_start,
                                      wvlStop=wvl_stop)['image']
        data = np.transpose(data)
        debug_images.append(data)
        _, median, std = stats.sigma_clipped_stats(data,
                                                   sigma=3.0,
                                                   mask_value=0)
        # Zero-count pixels are excluded from source finding and centroiding.
        mask = np.zeros_like(data, dtype=bool)
        mask[data == 0] = True

        sources = DAOStarFinder(fwhm=fwhm_guess,
                                threshold=5. * std)(data - median, mask=mask)
        # The finder may return None (or an empty table) when nothing passes the threshold.
        if sources is None or len(sources) == 0:
            raise RuntimeError('No source found in {}'.format(file))
        source = sources[sources['flux'].argmax()]
        source_est.append((source['xcentroid'], source['ycentroid']))

        position = centroids.centroid_sources(data,
                                              source['xcentroid'],
                                              source['ycentroid'],
                                              box_size=int(box_size),
                                              centroid_func=centroid_2dg,
                                              mask=mask)
        pixel_positions.append((position[0][0], position[1][0]))

    pixel_positions = np.array(pixel_positions)
    conex_positions = np.array(dither.pos)

    xform_con2pix = tf.estimate_transform('polynomial',
                                          conex_positions,
                                          pixel_positions,
                                          order=fit_power)
    xform_pix2con = tf.estimate_transform('polynomial',
                                          pixel_positions,
                                          conex_positions,
                                          order=fit_power)

    if plot:
        _plot_dither_centroids(debug_images, source_est, pixel_positions, box_size)
        _plotresiduals(xform_con2pix, conex_positions, pixel_positions)

    return xform_con2pix, xform_pix2con


def _plot_dither_centroids(debug_images, source_est, pixel_positions, box_size):
    """Show each dither image with its estimated (red +) and centroided (blue +) source."""
    n_images = len(debug_images)
    if not n_images:
        return
    # Grid large enough for every image. squeeze=False keeps `axs` 2D even for a 1x1 or 1xN
    # layout.
    ncols = int(np.ceil(np.sqrt(n_images)))
    nrows = int(np.ceil(n_images / ncols))
    _, axs = plt.subplots(nrows, ncols, figsize=(20, 15), squeeze=False)
    marker = '+'
    ms, mew = 30, 2.
    for index, (ax, image, (x_est, y_est), (x_cen, y_cen)) in enumerate(
            zip(axs.flat, debug_images, source_est, pixel_positions)):
        ax.imshow(image, origin='lower', interpolation='nearest', cmap='viridis')
        ax.add_patch(
            Rectangle((x_est - (box_size / 2), y_est - (box_size / 2)),
                      box_size,
                      box_size,
                      linewidth=1,
                      edgecolor='r',
                      fill=None))
        ax.plot(x_est, y_est, color='red', marker=marker, ms=ms, mew=mew)
        ax.plot(x_cen, y_cen, color='blue', marker=marker, ms=ms, mew=mew)
        ax.set_title('Red + = Estimate, Blue + = Centroid for Dither Pos %i' % index)
    # Hide grid cells left empty when the image count does not fill the grid.
    for ax in axs.flat[n_images:]:
        ax.axis('off')
    plt.show()