def __next__(self):
    """Runs the deblender and measurement on the next batch of blends.

    Returns:
        draw_blend_generator output, deblender output and measurement output.
    """
    blend_output = next(self.draw_blend_generator)
    deblend_results = {}
    measured_results = {}
    input_args = [(blend_output, i) for i in range(self.batch_size)]
    batch_results = multiprocess(
        self.run_batch,
        input_args,
        self.cpus,
        self.multiprocessing,
        self.verbose,
    )
    for i in range(self.batch_size):
        deblend_results.update({i: batch_results[i][0]})
        measured_results.update({i: batch_results[i][1]})
    if self.verbose:
        print("Measurement performed on batch")
    return blend_output, deblend_results, measured_results
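
# Minimal, hypothetical sketch of the fan-out contract that the positional call
# `multiprocess(self.run_batch, input_args, self.cpus, self.multiprocessing, self.verbose)`
# above appears to assume: map a callable over a list of argument tuples, optionally
# across a process pool. It is NOT the library's implementation and is named
# `_multiprocess_sketch` so it does not shadow the actual `multiprocess` helper.
from multiprocessing import Pool


def _multiprocess_sketch(func, input_args, cpus=1, use_pool=False, verbose=False):
    """Apply ``func`` to every argument tuple in ``input_args`` and return the results."""
    if use_pool and cpus > 1:
        if verbose:
            print(f"Running {len(input_args)} jobs on {cpus} processes")
        with Pool(cpus) as pool:
            return pool.starmap(func, input_args)  # parallel map over argument tuples
    return [func(*args) for args in input_args]  # serial fallback
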
def __next__(self):
    """Outputs dictionary containing blend output (images and catalogs) in batches.

    Returns:
        output: Dictionary with blend images, isolated object images, blend catalog,
            PSF images and WCS.
    """
    blend_list = {}
    blend_images = {}
    isolated_images = {}
    blend_cat = next(self.blend_generator)
    mini_batch_size = np.max([self.batch_size // self.cpus, 1])
    psfs = {}
    wcss = {}

    for s in self.surveys:
        pix_stamp_size = int(self.stamp_size / s.pixel_scale)

        # make PSF and WCS
        psf = []
        for filt in s.filters:
            if callable(filt.psf):
                # generate the PSF with the provided function
                generated_psf = filt.psf()
                if isinstance(generated_psf, galsim.GSObject):
                    psf.append(generated_psf)
                else:
                    raise TypeError(
                        "The PSF generated by the provided function "
                        f"for filter '{filt.name}' is not a galsim object"
                    )
            elif isinstance(filt.psf, galsim.GSObject):
                psf.append(filt.psf)  # or directly retrieve the PSF
            else:
                raise TypeError(
                    f"The PSF within filter '{filt.name}' is neither a "
                    "function nor a galsim object"
                )
        wcs = make_wcs(s.pixel_scale, (pix_stamp_size, pix_stamp_size))
        psfs[s.name] = psf
        wcss[s.name] = wcs

        input_args = []
        seedseq_minibatch = self.seedseq.spawn(self.batch_size // mini_batch_size + 1)
        for i in range(0, self.batch_size, mini_batch_size):
            cat = copy.deepcopy(blend_cat[i : i + mini_batch_size])
            input_args.append((cat, psf, wcs, s, seedseq_minibatch[i // mini_batch_size]))

        # multiprocess and join results
        # ideally, each cpu processes a single mini_batch
        mini_batch_results = multiprocess(
            self.render_mini_batch,
            input_args,
            cpus=self.cpus,
            verbose=self.verbose,
        )

        # join results across mini-batches.
        batch_results = list(chain(*mini_batch_results))

        # decide image_shape based on channels_last bool.
        option1 = (len(s.filters), pix_stamp_size, pix_stamp_size)
        option2 = (pix_stamp_size, pix_stamp_size, len(s.filters))
        image_shape = option1 if not self.channels_last else option2

        # organize results.
        blend_images[s.name] = np.zeros((self.batch_size, *image_shape))
        isolated_images[s.name] = np.zeros((self.batch_size, self.max_number, *image_shape))
        blend_list[s.name] = []
        for i in range(self.batch_size):
            blend_images[s.name][i] = batch_results[i][0]
            isolated_images[s.name][i] = batch_results[i][1]
            blend_list[s.name].append(batch_results[i][2])

        # save results if requested.
        if self.save_path is not None:
            if not os.path.exists(os.path.join(self.save_path, s.name)):
                os.mkdir(os.path.join(self.save_path, s.name))
            np.save(os.path.join(self.save_path, s.name, "blended"), blend_images[s.name])
            np.save(os.path.join(self.save_path, s.name, "isolated"), isolated_images[s.name])
            for i in range(len(batch_results)):
                blend_list[s.name][i].write(
                    os.path.join(self.save_path, s.name, f"blend_info_{i}"),
                    format="ascii",
                    overwrite=True,
                )

    if self.is_multiresolution:
        output = {
            "blend_images": blend_images,
            "isolated_images": isolated_images,
            "blend_list": blend_list,
            "psf": psfs,
            "wcs": wcss,
        }
    else:
        survey_name = self.surveys[0].name
        output = {
            "blend_images": blend_images[survey_name],
            "isolated_images": isolated_images[survey_name],
            "blend_list": blend_list[survey_name],
            "psf": psfs[survey_name],
            "wcs": wcss[survey_name],
        }
    return output
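
# Hypothetical helper, not part of the module above: it restates the channels_last
# convention used when allocating `blend_images` and `isolated_images`, so the two
# possible array layouts are explicit.
def batch_image_shape(batch_size, n_filters, pix_stamp_size, channels_last):
    """Shape of the per-survey blend-image array for either channel convention."""
    image_shape = (
        (pix_stamp_size, pix_stamp_size, n_filters)  # channels last (NHWC)
        if channels_last
        else (n_filters, pix_stamp_size, pix_stamp_size)  # channels first (NCHW)
    )
    return (batch_size, *image_shape)


# e.g. 8 blends, 6 filters, 120x120-pixel stamps:
assert batch_image_shape(8, 6, 120, channels_last=False) == (8, 6, 120, 120)
assert batch_image_shape(8, 6, 120, channels_last=True) == (8, 120, 120, 6)
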
def __next__(self):
    """Generates the next batch of blend images and isolated object images.

    Returns:
        Dictionary with blend images, isolated object images, blend catalog,
        and observing conditions.
    """
    batch_blend_cat, batch_obs_cond = {}, {}
    blend_images = {}
    isolated_images = {}
    for s in self.surveys:
        pix_stamp_size = int(self.stamp_size / s.pixel_scale)
        batch_blend_cat[s.name], batch_obs_cond[s.name] = [], []
        blend_images[s.name] = np.zeros(
            (self.batch_size, pix_stamp_size, pix_stamp_size, len(s.bands))
        )
        isolated_images[s.name] = np.zeros(
            (
                self.batch_size,
                self.max_number,
                pix_stamp_size,
                pix_stamp_size,
                len(s.bands),
            )
        )

    in_batch_blend_cat = next(self.blend_generator)
    obs_conds = next(self.observing_generator)  # same for every blend in batch.
    mini_batch_size = np.max([self.batch_size // self.cpus, 1])

    for s in self.surveys:
        input_args = [
            (
                copy.deepcopy(in_batch_blend_cat[i : i + mini_batch_size]),
                copy.deepcopy(obs_conds[s.name]),
                s,
            )
            for i in range(0, self.batch_size, mini_batch_size)
        ]

        # multiprocess and join results
        # ideally, each cpu processes a single mini_batch
        mini_batch_results = multiprocess(
            self.render_mini_batch,
            input_args,
            self.cpus,
            self.multiprocessing,
            self.verbose,
        )

        # join results across mini-batches.
        batch_results = list(chain(*mini_batch_results))

        # organize results.
        for i in range(self.batch_size):
            blend_images[s.name][i] = batch_results[i][0]
            isolated_images[s.name][i] = batch_results[i][1]
            batch_blend_cat[s.name].append(batch_results[i][2])

    if len(self.surveys) > 1:
        output = {
            "blend_images": blend_images,
            "isolated_images": isolated_images,
            "blend_list": batch_blend_cat,
            "obs_condition": obs_conds,
        }
    else:
        survey_name = self.surveys[0].name
        output = {
            "blend_images": blend_images[survey_name],
            "isolated_images": isolated_images[survey_name],
            "blend_list": batch_blend_cat[survey_name],
            "obs_condition": obs_conds[survey_name],
        }
    return output
def __next__(self):
    """Return measurement results on a single batch from the draw_blend_generator.

    Returns:
        draw_blend_generator output from its `__next__` method.
        measurement_results (dict): Dictionary with keys being the name of each
            `measure_function` passed in. Each value is a dictionary containing keys
            `catalog`, `deblended_images`, and `segmentation` storing the values
            returned by the corresponding `measure_function` for one batch.
    """
    blend_output = next(self.draw_blend_generator)
    catalog = {}
    segmentation = {}
    deblended_images = {}
    for f in self.measure_functions:
        for m in range(len(self.measure_kwargs)):
            key_name = f.__name__ + str(m) if len(self.measure_kwargs) > 1 else f.__name__
            catalog[key_name] = []
            segmentation[key_name] = []
            deblended_images[key_name] = []

    for m, measure_kwargs in enumerate(self.measure_kwargs):
        args_iter = ((blend_output, i) for i in range(self.batch_size))
        kwargs_iter = repeat(measure_kwargs)
        measure_output = multiprocess(
            self.run_batch,
            args_iter,
            kwargs_iter=kwargs_iter,
            cpus=self.cpus,
            verbose=self.verbose,
        )
        if self.verbose:
            print(f"Measurement {m} performed on batch")

        for i, f in enumerate(self.measure_functions):
            key_name = f.__name__ + str(m) if len(self.measure_kwargs) > 1 else f.__name__
            for j in range(len(measure_output)):
                catalog[key_name].append(measure_output[j][i].get("catalog", None))
                segmentation[key_name].append(measure_output[j][i].get("segmentation", None))
                deblended_images[key_name].append(
                    measure_output[j][i].get("deblended_images", None)
                )

            # If multiresolution, we reverse the order between the survey name and
            # the index of the blend.
            if self.is_multiresolution:
                survey_keys = list(blend_output["blend_list"].keys())
                # We duplicate the catalog for each survey to get the pixel coordinates.
                catalogs_temp = {}
                for surv in survey_keys:
                    catalogs_temp[surv] = add_pixel_columns(
                        catalog[key_name], blend_output["wcs"][surv]
                    )
                catalog[key_name] = catalogs_temp
                segmentation[key_name] = reverse_list_dictionary(
                    segmentation[key_name], survey_keys
                )
                deblended_images[key_name] = reverse_list_dictionary(
                    deblended_images[key_name], survey_keys
                )
            else:
                catalog[key_name] = add_pixel_columns(catalog[key_name], blend_output["wcs"])

            # save results if requested.
            if self.save_path is not None:
                if not os.path.exists(os.path.join(self.save_path, key_name)):
                    os.mkdir(os.path.join(self.save_path, key_name))
                if segmentation[key_name] is not None:
                    np.save(
                        os.path.join(self.save_path, key_name, "segmentation"),
                        segmentation[key_name],
                    )
                if deblended_images[key_name] is not None:
                    np.save(
                        os.path.join(self.save_path, key_name, "deblended_images"),
                        deblended_images[key_name],
                    )
                for j, cat in enumerate(catalog[key_name]):
                    cat.write(
                        os.path.join(self.save_path, key_name, f"detection_catalog_{j}"),
                        format="ascii",
                        overwrite=True,
                    )

    measure_results = {
        "catalog": catalog,
        "segmentation": segmentation,
        "deblended_images": deblended_images,
    }
    return blend_output, measure_results
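
# Hypothetical illustration (not part of the module above) of how the result keys in
# `catalog`, `segmentation` and `deblended_images` are built: one key per measure
# function, suffixed with the index of the kwargs set only when more than one set of
# measure_kwargs is supplied. `sep_measure` is a placeholder name used purely for
# this example.
def result_key(measure_function, kwargs_index, n_kwargs_sets):
    """Key under which one measure function's results are stored above."""
    name = measure_function.__name__
    return name + str(kwargs_index) if n_kwargs_sets > 1 else name


def sep_measure(batch, idx, **kwargs):  # placeholder measure function
    ...


assert result_key(sep_measure, 0, n_kwargs_sets=1) == "sep_measure"
assert result_key(sep_measure, 1, n_kwargs_sets=3) == "sep_measure1"
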