Example 1
    def _loadDataSet(self, event_type, flag, size):
        session: Session = DEDSession()
        part_size = int(size / 2)

        event_filter = Map.events.any(Event.type == event_type)

        query = session.query(Map).filter(
            Map.path != None, Map.flag == flag,
            event_filter).order_by(func.random()).with_entities(Map.id)
        result = query.limit(part_size).all()
        event_map_ids = list(map(lambda x: x[0], result))

        query = session.query(Map).filter(
            Map.path != None, Map.flag == flag,
            not_(event_filter)).order_by(func.random()).with_entities(Map.id)
        result = query.limit(part_size).all()
        none_ids = list(map(lambda x: x[0], result))

        session.close()

        data = [[id, [0, 1]] for id in event_map_ids] + [[id, [1, 0]]
                                                         for id in none_ids]
        if len(data) < size:
            raise ValueError("Not enough entries available!")

        random.shuffle(data)
        return data
Example 2
def test_imageanimator_figure():
    AIA_171 = sunpy.data.test.get_test_filepath('aia_171_level1.fits')
    KCOR = sunpy.data.test.get_test_filepath(
        '20181209_180305_kcor_l1.5_rebinned.fits')
    map_sequence = sunpy.map.Map(AIA_171, KCOR, sequence=True)
    sequence_array = map_sequence.as_array()
    wcs_input_dict = {
        f'{key}{n+1}': map_sequence.all_meta()[0].get(f'{key}{n}')
        for n, key in product([1, 2], ['CTYPE', 'CUNIT', 'CDELT'])
    }
    t0, t1 = map(parse_time, [k['date-obs'] for k in map_sequence.all_meta()])
    time_diff = (t1 - t0).to(u.s)
    wcs_input_dict.update({
        'CTYPE1': 'Time',
        'CUNIT1': time_diff.unit.name,
        'CDELT1': time_diff.value
    })
    wcs = astropy.wcs.WCS(wcs_input_dict)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", SunpyDeprecationWarning)
        wcs_anim = ImageAnimatorWCS(sequence_array,
                                    wcs=wcs,
                                    vmax=1000,
                                    image_axes=[0, 1])

    return wcs_anim.fig
Example 3
import numpy as np


def polsfromfitsheader(header):
    '''
    Get polarisation information from a FITS header.

    :param header: FITS header
    :return pols: list of polarisation (Stokes) labels
    '''
    try:
        # Enumerate the Stokes axis: CRVAL4 + n * CDELT4 for each plane.
        stokeslist = [
            '{}'.format(int(ll))
            for ll in (header["CRVAL4"] +
                       np.arange(header["NAXIS4"]) * header["CDELT4"])
        ]
        stokesdict = {
            '1': 'I',
            '2': 'Q',
            '3': 'U',
            '4': 'V',
            '-1': 'RR',
            '-2': 'LL',
            '-3': 'RL',
            '-4': 'LR',
            '-5': 'XX',
            '-6': 'YY',
            '-7': 'XY',
            '-8': 'YX'
        }
        # list() so the result is a real list under Python 3, not a lazy map.
        pols = list(map(lambda x: stokesdict[x], stokeslist))
    except KeyError as e:
        raise ValueError("error in fits header!") from e
    return pols
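
A quick sanity check of the mapping, using a minimal stand-in header (the keyword values here are hypothetical):

# Hypothetical header carrying only the three keywords the function reads.
# CRVAL4=-5, CDELT4=-1, NAXIS4=2 enumerates Stokes codes -5 and -6.
fake_header = {'CRVAL4': -5, 'NAXIS4': 2, 'CDELT4': -1}
print(polsfromfitsheader(fake_header))  # ['XX', 'YY']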
Example 4
import numpy as np
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from reproject import reproject_interp
from reproject.mosaicking import reproject_and_coadd
import sunpy.map


def combine_maps(maps_list):
    # Reproject all input maps onto a common heliographic grid and co-add them.
    shape_out = (180, 360)  # This is set deliberately low to reduce memory consumption
    header = sunpy.map.make_fitswcs_header(shape_out,
                                           SkyCoord(0, 0, unit=u.deg,
                                                    frame="heliographic_stonyhurst",
                                                    obstime=maps_list[0].date),
                                           scale=[180 / shape_out[0],
                                                  360 / shape_out[1]] * u.deg / u.pix,
                                           wavelength=int(maps_list[0].meta['wavelnth']) * u.AA,
                                           projection_code="CAR")
    out_wcs = WCS(header)
    coordinates = tuple(map(sunpy.map.all_coordinates_from_map, maps_list))
    weights = [coord.transform_to("heliocentric").z.value for coord in coordinates]
    weights = [(w / np.nanmax(w)) ** 3 for w in weights]
    for w in weights:
        w[np.isnan(w)] = 0

    array, _ = reproject_and_coadd(maps_list, out_wcs, shape_out,
                                   input_weights=weights,
                                   reproject_function=reproject_interp,
                                   match_background=True,
                                   background_reference=0)
    outmaps = sunpy.map.Map((array, header))
    return outmaps
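
A hypothetical call, assuming a directory of co-temporal full-disk EUV FITS files (the glob pattern is an assumption):

import glob

# Hypothetical input files; any set of overlapping full-disk maps will do.
maps_list = sunpy.map.Map(sorted(glob.glob("./data/aia_171_*.fits")))
combined = combine_maps(maps_list)
combined.peek()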
Example 5
    def get_interpolation_value(self, date_timestamp):
        # Interpolate the effective area at the given POSIX timestamp; the
        # [:-4] slice strips the last four characters of each key (e.g. a
        # '.000' milliseconds suffix) before parsing.
        start_dates = list(
            map(
                lambda x: pd.to_datetime(x[:-4], format='%Y-%m-%dT%H:%M:%S').
                timestamp(), self.interpolation_values.keys()))
        eff_area = list(self.interpolation_values.values())
        return np.interp(date_timestamp, start_dates, eff_area)
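
For illustration, the same interpolation outside the class, with hypothetical calibration data (the key format is chosen so that the [:-4] slice above removes the '.000' suffix):

import numpy as np
import pandas as pd

# Hypothetical effective-area table.
interpolation_values = {
    '2011-01-01T00:00:00.000': 1.00,
    '2013-01-01T00:00:00.000': 0.85,
}
start_dates = [pd.to_datetime(k[:-4], format='%Y-%m-%dT%H:%M:%S').timestamp()
               for k in interpolation_values]
eff_area = list(interpolation_values.values())
query = pd.Timestamp('2012-01-01T00:00:00').timestamp()
print(np.interp(query, start_dates, eff_area))  # ~0.925, just under halfway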
Example 6
def rotation(lo, la, v, smooth):
    # Build the combined rotation: longitude about z, then latitude about y.
    q1 = Quaternion(axis=[0.0, 0.0, 1.0], degrees=lo)
    q2 = Quaternion(axis=[0.0, 1.0, 0.0], degrees=la)
    q_rot = q2 * q1

    rot_matrix = q_rot.rotation_matrix
    format_v = np.array(list(zip(np.ravel(v[0]), np.ravel(v[1]), np.ravel(v[2]))))
    format_v = format_v @ rot_matrix

    return np.array(list(map(lambda x: np.reshape(x, (smooth, smooth)), zip(*format_v))))
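
A minimal sketch of calling rotation, assuming pyquaternion is installed; the grid and field values are hypothetical:

import numpy as np
from pyquaternion import Quaternion

# Hypothetical 4x4 field of unit vectors along z (so smooth=4).
smooth = 4
vx, vy = np.zeros((smooth, smooth)), np.zeros((smooth, smooth))
vz = np.ones((smooth, smooth))

rotated = rotation(lo=0.0, la=90.0, v=(vx, vy, vz), smooth=smooth)
print(rotated.shape)  # (3, 4, 4): one rotated component array per axis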
Example 7
def encode_and_split(chain_codes):
    codes = []

    for chains in chain_codes:
        if isinstance(chains, bytes):
            chains = chains.decode("utf-8")

        split_chain = list(map(int, str(chains)))
        codes.append(split_chain)

    return codes
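
For example, each chain code becomes the list of its digits, whether it arrives as bytes or str:

print(encode_and_split([b"0123", "7654"]))  # [[0, 1, 2, 3], [7, 6, 5, 4]]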
Example 8
sequence_array = map_sequence.as_array()

###############################################################################
# Now we need to create the `~astropy.wcs.WCS` header that
# `~sunpy.visualization.animator.ImageAnimatorWCS` will need.
# To create the new header we can use the stored meta information from the
# ``map_sequence``.

# This dictionary comprehension extracts the three basic keywords we need
# to create an astropy.wcs.WCS header: 'CTYPE', 'CUNIT' and 'CDELT',
# from the meta information stored in the ``map_sequence``.
wcs_input_dict = {f'{key}{n+1}': map_sequence.all_meta()[0].get(f'{key}{n}')
                  for n, key in product([1, 2], ['CTYPE', 'CUNIT', 'CDELT'])}

# Now we need to get the time difference between the two observations.
t0, t1 = map(parse_time, [k['date-obs'] for k in map_sequence.all_meta()])
time_diff = (t1 - t0).to(u.s)
wcs_input_dict.update({'CTYPE1': 'Time', 'CUNIT1': time_diff.unit.name, 'CDELT1': time_diff.value})

# We can now just pass this into astropy.wcs.WCS to create our WCS header.
wcs = astropy.wcs.WCS(wcs_input_dict)

# Now the resulting WCS object will look like:
print(wcs)

###############################################################################
# Now we can create the animation.
# `~sunpy.visualization.animator.ImageAnimatorWCS` assumes the last two axes
# are the two that form the image. However, they are the first two in this case,
# so we change this by passing in the ``image_axes`` keyword.
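
The call itself is cut off in this snippet; a sketch consistent with Example 2 above would be:

wcs_anim = ImageAnimatorWCS(sequence_array, wcs=wcs, vmax=1000,
                            image_axes=[0, 1])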
Example 9
(Optional show or save fig)
"""

import sunpy.io
import sunpy.map
import sunpy.data.sample
import matplotlib.pyplot as plt
from numpy import asarray
import numpy as np
from numpy import savez_compressed

base = "/Users/lxy/Desktop/Rice/PHYS 491 & 493 Research/data/"
years = ["2019"]
months = [str(i + 1) for i in range(12)]
months = list(map(lambda i: '0' + i if len(i) == 1 else i, months))
days = [str(i + 1) for i in range(31)]
days = list(map(lambda i: '0' + i if len(i) == 1 else i, days))

a1 = -7.31 * 10**(-2)
a2 = -9.75 * 10**(-1)
a3 = -9.90 * 10**(-2)
a4 = -2.84 * 10**(-3)
f = 0.31

# Instead of searching manually, I found a threshold
ignore_list = [
    "20190101", "20190110", "20190112", "20190113", "20190114", "20190115",
    "20190116", "20190118", "20190119", "20190120", "20190228", "20190307",
    "20190327", "20190328", "20190330", "20190331", "20190422", "20190423",
    "20190424", "20190426", "20190427", "20190428", "20190429", "20190430",
Example 10
# The easiest method for building the array is to create a `~sunpy.map.MapSequence`.

# Here we only use two files but you could pass in a larger selection of files.
map_sequence = sunpy.map.Map(AIA_171_IMAGE, AIA_193_IMAGE, sequence=True)

# Now we can simply convert the sequence into a NumPy array.
sequence_array = map_sequence.as_array()

###############################################################################
# Now we need to create the `~astropy.wcs.WCS` header that
# `~sunpy.visualization.animator.ArrayAnimatorWCS` will need.
# To create the new header we can use the stored meta information from the
# ``map_sequence``.

# Now we need to get the time difference between the two observations.
t0, t1 = map(parse_time, [k['date-obs'] for k in map_sequence.all_meta()])
time_diff = (t1 - t0).to(u.s)

m = map_sequence[0]

wcs = astropy.wcs.WCS(naxis=3)
wcs.wcs.crpix = u.Quantity([0*u.pix] + list(m.reference_pixel))
wcs.wcs.cdelt = [time_diff.value] + list(u.Quantity(m.scale).value)
wcs.wcs.crval = [0, m._reference_longitude.value, m._reference_latitude.value]
wcs.wcs.ctype = ['TIME'] + list(m.coordinate_system)
wcs.wcs.cunit = ['s'] + list(m.spatial_units)
wcs.rsun = m.rsun_meters
wcs.heliographic_observer = m.observer_coordinate

# Now the resulting WCS object will look like:
print(wcs)
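
A hedged sketch of the animation step this WCS is being built for; the exact call is an assumption based on `~sunpy.visualization.animator.ArrayAnimatorWCS` taking (data, wcs, slices):

# Assumed call: axis 0 is time, so the spatial axes are flagged as 'x' and 'y'.
wcs_anim = ArrayAnimatorWCS(sequence_array, wcs, [0, 'x', 'y'], vmax=1000)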
Example 11
fits_files_temp = []
#loop over all wavelengths in array
for i in wav:
    try:
        #fits_files = glob(arch+'*_'+i+'.fits')
        fits_files = [
            j.strftime(arch + 'AIA%Y%m%d_%H%M_' + i + '.fits')
            for j in real_cad
        ]
    except Exception:
        fits_files = glob(sdir + '/xrt/*fits')
    #make sure the wavelength header agrees with found value
    #fits_files = check_wavelength(fits_files,i,arch)
    fits_files_temp.append(fits_files)
#transpose list array; list() the lazy map so len() works under Python 3
fits_files = list(map(list, zip(*fits_files_temp)))
print(len(fits_files))

#use font
#font = ImageFont.truetype("/Library/Fonts/Times New Roman Bold.ttf", 56)
font = ImageFont.truetype("/Library/Fonts/Arial Unicode.ttf", 56)
font = ImageFont.truetype(
    "/Volumes/Pegasus/jprchlik/anaconda2/lib/python2.7/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Bold.ttf",
    56)

#number of processors
n_proc = 8

#image and movie height and width
wy = 2048
wx = 2048
Example 12
            # Combine the maps onto a common output grid
            shape_out = (
                180, 360
            )  # This is set deliberately low to reduce memory consumption
            header = sunpy.map.make_fitswcs_header(
                shape_out,
                SkyCoord(0,
                         0,
                         unit=u.deg,
                         frame="heliographic_stonyhurst",
                         obstime=maps[0].date),
                scale=[180 / shape_out[0], 360 / shape_out[1]] * u.deg / u.pix,
                wavelength=int(maps[0].meta['wavelnth']) * u.AA,
                projection_code="CAR")
            out_wcs = WCS(header)
            coordinates = tuple(map(sunpy.map.all_coordinates_from_map, maps))
            weights = [
                coord.transform_to("heliocentric").z.value
                for coord in coordinates
            ]
            weights = [(w / np.nanmax(w))**3 for w in weights]
            for w in weights:
                w[np.isnan(w)] = 0

            array, _ = reproject_and_coadd(maps,
                                           out_wcs,
                                           shape_out,
                                           input_weights=weights,
                                           reproject_function=reproject_interp,
                                           match_background=True,
                                           background_reference=0)
Example 13
from functools import partial

import astropy.units as u
import matplotlib.pyplot as plt
from astropy.time import Time
from astropy.table import QTable, vstack
from skimage.measure import regionprops, label, regionprops_table
from skimage.color import label2rgb

import sunpy.map

from stara import stara

maps = sunpy.map.Map("./data/*720*")
maps = [m.resample((1024, 1024) * u.pix) for m in maps]

segs = list(map(partial(stara, limb_filter=10 * u.percent), maps))

def get_regions(segmentation, smap):
    labelled = label(segmentation)
    if labelled.max() == 0:
        return QTable()

    regions = regionprops_table(labelled, smap.data,
                                properties=["label",
                                            "centroid",
                                            "area",
                                            "min_intensity"])

    regions['obstime'] = Time([smap.date] * regions['label'].size)
    regions['center_coord'] = smap.pixel_to_world(regions['centroid-0'] * u.pix,
                                                  regions['centroid-1'] * u.pix).heliographic_stonyhurst
Example 14
    def showQuery(self):
        # OR together the full attribute set of every layer, then issue one query.
        attrs = reduce(lambda x, y: x | y,
                       map(lambda z: z._LayerHandler__fullAttr, self.layers))
        return self.client.query(attrs)
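
In isolation, the reduce/map pattern above just ORs one attribute set per layer into a single set:

from functools import reduce

# Union of per-layer attribute sets via repeated |, as in showQuery above.
layer_attrs = [{"a", "b"}, {"b", "c"}, {"d"}]
print(reduce(lambda x, y: x | y, layer_attrs))  # {'a', 'b', 'c', 'd'} (set order may vary)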
Example 15
# ordering of the sorted_wavelengths list has to match the axes generation.
axes = {wave:ax for wave, ax in zip(sorted_wavelengths, left_axes+right_axes+centre_axes)}

def mapwplot(key):
    """
    A function to plot the map and write the wavelength in the bottom left corner
    """
    # We do this manually because of a bug in map.plot()
    axes[key].imshow(maps[key].data, cmap=maps[key].cmap, norm=maps[key].mpl_color_normalizer,
                     extent=maps[key].xrange + maps[key].yrange, origin='lower')
    if key not in (6173.0, sorted_wavelengths[-1]): # Not the main frame or HMI
        axes[key].text(bx+label_padx, by+label_pady, (key*u.AA).to(u.nm)._repr_latex_(),
                       color='w', fontdict={'fontsize':16})

# Run the above function on every key in the maps dictionary; list() forces
# the lazy map() to execute under Python 3.
list(map(mapwplot, maps.keys()))

# Remove all tick labels from all axes.
[ax.axes.get_xaxis().set_ticks([]) for ax in axes.values()]
[ax.axes.get_yaxis().set_ticks([]) for ax in axes.values()]

# Manually write on the text for the main image and the HMI image.
axes[6173.].text(bx+label_padx, by+label_pady, r"LOS Magnetic Field", color='w',
                 fontdict={'fontsize':16})
axes[sorted_wavelengths[-1]].text(-950, -990, (1700*u.AA).to(u.nm)._repr_latex_(),
                                  color='w', fontdict={'fontsize':16})

# Add a rectangle to the main image showing the crop box for the satellite images.
re = plt.Rectangle((bx,by), w, h, fc='none')
axes[sorted_wavelengths[-1]].add_patch(re)
Example 16
    def __getitem__(self, idx):
        # sampling with probability from SIDC
        print("Sampling from SIDC...")
        row = self.sidc_csv.sample(weights=self.sidc_csv[4])
        day = '/'.join(map(str, row.iloc[0][:-1]))
        date = datetime.strptime(day + ' 12:00:00', '%Y/%m/%d %H:%M:%S')

        # loading sunspot data from DPD
        print("Loading sunspot data...")
        dpd = self.fenyi_sunspot.query(("year == @date.year & "
                                        "month == @date.month & "
                                        "day == @date.day"))

        time = datetime.strptime(
            '-'.join([str(i) for i in list(dpd.iloc[0])[1:7]]),
            '%Y-%m-%d-%H-%M-%S')
        start_time = (time -
                      timedelta(minutes=30)).strftime('%Y-%m-%dT%H:%M:%S')
        end_time = (time + timedelta(minutes=30)).strftime('%Y-%m-%dT%H:%M:%S')

        continuum_file, magnetic_file = None, None
        try:
            print("Searching VSO...")
            continuum_file, magnetic_file = search_VSO(start_time, end_time)
            hmi_cont = Map(continuum_file)
            hmi_mag = Map(magnetic_file)
        except Exception as e:
            print(e)
            # search_VSO may have raised before the paths were assigned.
            if continuum_file is not None:
                remove_if_exists(continuum_file)
            if magnetic_file is not None:
                remove_if_exists(magnetic_file)
            return self.__getitem__(idx)

        # get the data from the maps
        img_cont = normalize_map(hmi_cont)
        img_mag = 2 * normalize_map(hmi_mag) - 1
        inputs = np.array([img_cont, img_mag])

        # get the coordinates and the date of the sunspots from DPD
        print("Creating mask...")
        ss_coord = dpd[['heliographic_latitude', 'heliographic_longitude']]
        ss_date = parse_time(time)
        sunspots = rotate_coord(hmi_cont, ss_coord, ss_date)

        # mask = (255 * img_cont).astype(np.uint8)
        mask = np.zeros(img_cont.shape, dtype=np.float32)

        ws = dpd[['projected_whole_spot', 'group_number', 'group_spot_number']]

        for index, row in ws.iterrows():
            wsa = row['projected_whole_spot']
            if wsa < 0:
                match = ws.query(("group_number == @row.group_number & "
                                  "group_spot_number == -@wsa"))
                area = match['projected_whole_spot'].iloc[0]
                ws.loc[row.name, 'projected_whole_spot'] = area

        groups = list(ws['group_number'].unique())
        disk_mask = np.where(255 * img_cont > 15)
        disk_mask = {(c[0], c[1]) for c in np.column_stack(disk_mask)}
        disk_mask_num_px = len(disk_mask)
        whole_spot_mask = set()

        for i in range(len(sunspots)):
            o = 4  # offset
            p = sunspots[i]
            # g_number = groups.index(ws.iloc[i]['group_number'])
            group = img_cont[int(p[1]) - o:int(p[1]) + o,
                             int(p[0]) - o:int(p[0]) + o]
            low = np.where(group == np.amin(group))

            center = (img_cont.shape[0] / 2, img_cont.shape[1] / 2)
            distance = np.linalg.norm(tuple(j - k for j, k in zip(center, p)))
            cosine_amplifier = math.cos(math.radians(1) * distance / center[0])
            norm_num_px = cosine_amplifier * ws.iloc[i]['projected_whole_spot']
            ss_num_px = 8.7 * norm_num_px * disk_mask_num_px / 10e6

            # Region-grow the spot from its darkest pixel until it covers the
            # expected number of pixels.
            new = set([(p[1] - o + low[1][0], p[0] - o + low[0][0])])
            whole_spot = set()
            candidates = dict()
            expansion_rate = 3
            while len(whole_spot) < ss_num_px:
                # Expand one pixel in every direction around the newest pixels.
                expand = {(n[0] + i, n[1] + j)
                          for i in [-1, 0, 1] for j in [-1, 0, 1] for n in new}
                for e in set(expand - whole_spot):
                    candidates[e] = img_cont[e]
                # Keep the darkest few candidates as the next growth front.
                new = sorted(candidates, key=candidates.get)[:expansion_rate]
                for n in new:
                    candidates.pop(n, None)
                whole_spot.update(set(new))

            whole_spot_mask.update(whole_spot)

        for c in set.intersection(whole_spot_mask, disk_mask):
            mask[c] = 1

        # show_mask(img_cont, mask)
        remove_if_exists(continuum_file)
        remove_if_exists(magnetic_file)

        input_batch, mask_batch = multi_scale_slice(inputs, mask)

        data_pair = {
            'img': torch.from_numpy(input_batch).float(),
            'mask': torch.from_numpy(mask_batch).float()
        }
        return data_pair
Example 17
    def read_all_images(self,
                        save_folder,
                        save_name,
                        in_types=['ew'],
                        out_types=['0304'],
                        new_size=(864, 864),
                        remake=False,
                        folder_indices=None,
                        correct_sensor_data=True,
                        remake_images=False):
        '''This function reads all entries in the processed folder and stores
        the resulting images in an HDF5 file with one dataset per image type,
        along with the date of each entry, which can be thought of as a unique
        identifier for the data point. It also removes a number of pixels from
        the '0304' image type so that the sun takes up the same amount of
        space in the image frame as in the 'ew' image type. All images are
        rescaled to a uniform size. If something goes wrong while reading an
        entry, the program skips it and moves on.

        Parameters
        ----------

        in_types, out_types : list(str)
            Lists of strings indicating the input and output image types
        new_size : (int, int)
            The rescaled (width, height) of the processed images

        Returns
        -------

        None
            The filtered dataset is written to ``save_folder / save_name``.
        '''
        save_location = save_folder / save_name
        temp_location = save_folder / 'temp.h5'

        allowed_types = in_types + out_types
        remake = remake or not save_location.exists()
        if remake:
            if save_location.exists():
                save_location.unlink()
            if temp_location.exists():
                temp_location.unlink()
            i, created_dataset = -1, False
            num_folders = sum([1 for _ in self.processed_path.iterdir()])
            with tqdm(total=num_folders,
                      desc='Processing dataset folders') as pbar:
                with h5py.File(temp_location, 'w') as hf:
                    data_shape = (num_folders, new_size[0], new_size[1])
                    hf.create_dataset('date', shape=(num_folders, ))
                    for type_name in allowed_types:
                        hf.create_dataset(type_name,
                                          shape=data_shape,
                                          dtype=np.float32)
                        #hf[type_name].attrs['date'] = image_dict['date'].value
                with h5py.File(temp_location, 'a') as hf:
                    for folder in self.processed_path.iterdir():
                        pbar.update(1)
                        try:
                            i += 1
                            if folder_indices is not None:
                                if i < folder_indices[0] or i > folder_indices[
                                        1]:
                                    continue
                            if not self.check_valid_fits_folder(
                                    folder, allowed_types):
                                continue
                            entry_date = pd.to_datetime(
                                folder.parts[-1], format="%Y-%m-%d-H%H-M%M")
                            image_dict = {}
                            image_dict['date'] = entry_date
                            hf['date'][i] = image_dict['date'].value / 10**9
                            for image in self.get_images(
                                    folder, allowed_types):
                                data = image.data
                                #Remove 80 pixels to have the sun take about the same space in the frame
                                #as the 'ew' images
                                if image.image_type in ['0171', '0304']:
                                    height, width = data.shape
                                    pixel_remove_width = 80
                                    data = data[pixel_remove_width:(
                                        height - pixel_remove_width),
                                                pixel_remove_width:(
                                                    width -
                                                    pixel_remove_width)]
                                if correct_sensor_data and image.image_type == '0304':
                                    data = data * 1 / self.get_interpolation_value(
                                        hf['date'][i])
                                image_dict[image.image_type] = resize(
                                    data, (new_size[0], new_size[1]))
                                image_dict[image.image_type] = image_dict[
                                    image.image_type].astype(dtype='float32')
                            #print("Read Folder Images!")
                            n_read = 0

                            for type_name in allowed_types:
                                hf[type_name][i, :, :] = image_dict[type_name]
                                n_read = hf[type_name].shape[0]
                            #print("Wrote images to hf!")
                            #hf['date'].resize((hf['date'].shape[0] + 1), axis=0)
                            #hf['date'][-1:] = image_dict['date'].value / 10 ** 9
                            #for type_name in allowed_types:
                            #    hf[type_name].resize((hf[type_name].shape[0] + 1), axis=0)
                            #    print(hf[type_name][-1:].shape)
                            #    print(image_dict[type_name].shape)
                            #    hf[type_name][-1:] = image_dict[type_name]
                            #    n_read = hf[type_name].shape[0]

                            #print('Type Name: ' + type_name + ' Shape: ' + str(hf[type_name].shape))
                            #print("Read " + str(n_read) + " images!")

                        except KeyboardInterrupt:
                            print("Quitting...")
                            break

                #except:
                #    print('Failed to read folder: ' + str(i))

            #print("Read " + str(n_read) + " images!")
        if remake_images or remake:
            if save_location.exists():
                save_location.unlink()
            excluded_dates_list = list(
                map(
                    lambda x: pd.to_datetime(x, format="%Y-%m-%d-H%H-M%M").
                    value / 10**9, self.excluded_dates))

            to_keep = []
            eps = 1e-7
            with h5py.File(temp_location, 'a') as hf:
                n_read = hf['date'].shape[0]
                for idx in tqdm(range(n_read), desc='Checking for bad images'):
                    date = hf['date'][idx]
                    matching_times = list(
                        filter(
                            lambda x: np.abs(x) < 100,
                            list(map(lambda x: x - date,
                                     excluded_dates_list))))
                    date_check = len(matching_times) == 0 and date != 0
                    #print(date_check)
                    nan_check, std_check = True, True
                    for type_name in allowed_types:
                        arr = hf[type_name][idx]
                        nan_check = nan_check and not np.isnan(arr.mean())
                        std_check = std_check and np.std(arr) > eps
                    if nan_check and std_check and date_check:
                        to_keep.append(idx)

                with h5py.File(save_location, 'w') as hf_save:
                    data_shape = (len(to_keep), 1, new_size[0], new_size[1])
                    hf_save.create_dataset('date',
                                           shape=(len(to_keep), ),
                                           chunks=True)
                    for in_type in in_types:
                        hf_save.create_dataset(in_type,
                                               shape=data_shape,
                                               dtype=np.float32,
                                               chunks=True)
                    for out_type in out_types:
                        hf_save.create_dataset(out_type,
                                               shape=data_shape,
                                               dtype=np.float32,
                                               chunks=True)
                    for i in tqdm(range(len(to_keep)),
                                  desc='Copying good images'):
                        for in_type in in_types:
                            hf_save[in_type][i, 0, :, :] = hf[in_type][
                                to_keep[i], :, :]
                        for out_type in out_types:
                            hf_save[out_type][i, 0, :, :] = hf[out_type][
                                to_keep[i], :, :]
                        hf_save['date'][i] = hf['date'][to_keep[i]]
                    '''
                    in_data_shape = (len(to_keep), len(in_types), new_size[0], new_size[1])
                    out_data_shape = (len(to_keep), len(out_types), new_size[0], new_size[1])
                    hf_save.create_dataset('x', shape=in_data_shape, dtype=np.float32, chunks=True)
                    hf_save.create_dataset('y', shape=out_data_shape, dtype=np.float32, chunks=True)
                    hf_save.create_dataset('date', shape=(len(to_keep),), chunks=True)
                    for i in range(len(to_keep)):
                        for in_type_idx in range(len(in_types)):
                            hf_save['x'][i, in_type_idx, :, :] = hf[in_types[in_type_idx]][to_keep[i], :, :]
                        for out_type_idx in range(len(out_types)):
                            hf_save['y'][i, out_type_idx, :, :] = hf[out_types[out_type_idx]][to_keep[i], :, :]
                        hf_save['date'][i] = hf['date'][to_keep[i]]
                    '''
                    print("Kept " + str(len(to_keep)) + " images!")