def test_ensemble_returned_values():
	"""
	Check that the ensemble returns values that are reasonable
	and within expected bounds.
	"""
	tm = corrections.TaskManager(INPUT_DIR)
	task = tm.get_task(starid=starid, camera=camera, ccd=ccd)

	# Initiate the class
	CorrClass = corrections.corrclass('ensemble')
	corr = CorrClass(INPUT_DIR, plot=False)
	inlc = corr.load_lightcurve(task)
	outlc, status = corr.do_correction(inlc.copy())

	# Check input validation
	#with pytest.raises(ValueError) as err:
	#	outlc, status = corr.do_correction('hello world')
	#assert 'The input to `do_correction` is not a TessLightCurve object!' in err.value.args[0]

	# Check contents
	assert len(outlc.flux) == len(inlc.flux), "Input flux is different length to output flux"
	assert all(inlc.time == outlc.time), "Input time is nonidentical to output time"
	assert all(outlc.flux != inlc.flux), "Input and output flux are identical."
	assert len(outlc.flux) == len(outlc.time), "Check time and flux have same length"

	# Check status
	assert status == corrections.STATUS.OK, "STATUS was not set appropriately"
def test_corrclass_type():
	"""Check that tesscorr.py returns the correct class"""
	CorrClass = corrclass()
	assert repr(CorrClass) == "<class 'corrections.ensemble.EnsembleCorrector'>"

	CorrClass = corrclass('ensemble')
	assert repr(CorrClass) == "<class 'corrections.ensemble.EnsembleCorrector'>"

	CorrClass = corrclass('cbv')
	assert repr(CorrClass) == "<class 'corrections.cbv_corrector.CBVCorrector.CBVCorrector'>"

	CorrClass = corrclass('kasoc_filter')
	assert repr(CorrClass) == "<class 'corrections.KASOCFilterCorrector.KASOCFilterCorrector'>"

	with pytest.raises(ValueError) as err:
		method = 'not-a-method'
		CorrClass = corrclass(method)
	assert err.value.args[0] == "Invalid method: '{0}'".format(method)
def test_corrclass_type():
	"""Check that tesscorr.py returns the correct class"""
	CorrClass = corrections.corrclass()
	print(CorrClass)
	assert CorrClass is corrections.EnsembleCorrector

	CorrClass = corrections.corrclass('ensemble')
	print(CorrClass)
	assert CorrClass is corrections.EnsembleCorrector

	CorrClass = corrections.corrclass('cbv')
	print(CorrClass)
	assert CorrClass is corrections.CBVCorrector

	CorrClass = corrections.corrclass('kasoc_filter')
	print(CorrClass)
	assert CorrClass is corrections.KASOCFilterCorrector

	method = 'not-a-method'
	with pytest.raises(ValueError) as err:
		CorrClass = corrections.corrclass(method)
	assert err.value.args[0] == "Invalid method: '{0}'".format(method)
def test_run_metadata():
	"""
	Check that the ensemble correction preserves the task metadata
	of the input lightcurve.
	"""
	tm = corrections.TaskManager(INPUT_DIR)
	task = tm.get_task(starid=starid, camera=camera, ccd=ccd)

	# Initiate the class
	CorrClass = corrections.corrclass('ensemble')
	corr = CorrClass(INPUT_DIR, plot=False)
	inlc = corr.load_lightcurve(task)
	outlc, status = corr.do_correction(inlc.copy())

	# Check metadata
	#assert 'fmean' in outlc.meta, "Metadata is incomplete"
	#assert 'fstd' in outlc.meta, "Metadata is incomplete"
	#assert 'frange' in outlc.meta, "Metadata is incomplete"
	#assert 'drange' in outlc.meta, "Metadata is incomplete"
	assert outlc.meta['task']['starid'] == inlc.meta['task']['starid'], "Metadata is incomplete"
	assert outlc.meta['task'] == inlc.meta['task'], "Metadata is incomplete"
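# NOTE: The test below takes `corrector`, `starid`, `cadence` and the `*_goal` values as
# arguments, so it is presumably driven by a pytest fixture (SHARED_INPUT_DIR) combined with
# `pytest.mark.parametrize`. The actual fixture and parameter table are not part of this
# excerpt; the sketch below uses hypothetical placeholder values purely to illustrate how
# such a parametrization could look.
#
#	@pytest.mark.parametrize('corrector,starid,cadence,var_goal,rms_goal,ptp_goal', [
#		('cbv', 29281992, 1800, 1.5e-10, 2.2e-6, 1.5e-6),        # placeholder goal values
#		('ensemble', 29281992, 1800, None, None, None),          # None skips that metric check
#		('kasoc_filter', 29281992, 1800, 1.5e-10, 2.2e-6, 1.5e-6),
#	])
#	def test_known_star(SHARED_INPUT_DIR, corrector, starid, cadence, var_goal, rms_goal, ptp_goal):
#		...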
def test_known_star(SHARED_INPUT_DIR, corrector, starid, cadence, var_goal, rms_goal, ptp_goal):
	"""
	Check that the corrector returns values that are reasonable
	and within expected bounds for a set of known stars.
	"""
	# All stars we check here come from the same sector and camera.
	# Define these here for the future where we may test on other combinations of these:
	sector = 1
	camera = 1

	__dir__ = os.path.abspath(os.path.dirname(__file__))
	logger = logging.getLogger(__name__)
	logger.info("-------------------------------------------------------------")
	logger.info("CORRECTOR = %s, SECTOR=%d, CADENCE=%s, STARID=%d", corrector, sector, cadence, starid)

	# All stars are from the same CCD, find the task for it:
	with corrections.TaskManager(SHARED_INPUT_DIR) as tm:
		task = tm.get_task(starid=starid, sector=sector, camera=camera, cadence=cadence)

	# Check that task was actually found:
	assert task is not None, "Task could not be found"

	# Load lightcurve that will also be plotted together with the result.
	# This lightcurve is of the same object, at a state where it was deemed that the
	# corrections were doing a good job.
	compare_lc_path = os.path.join(__dir__, 'compare', f'compare-{corrector}-s{sector:04d}-c{cadence:04d}-tic{starid:011d}.ecsv.gz')
	compare_lc = None
	if os.path.isfile(compare_lc_path):
		compare_lc = Table.read(compare_lc_path, format='ascii.ecsv')
	else:
		warnings.warn("Comparison data does not exist: " + compare_lc_path)

	# Initiate the class
	CorrClass = corrections.corrclass(corrector)
	with tempfile.TemporaryDirectory() as tmpdir:
		with CorrClass(SHARED_INPUT_DIR, plot=True) as corr:
			# Check basic parameters of object (from BaseCorrector):
			assert corr.input_folder == SHARED_INPUT_DIR, "Incorrect input folder"
			assert corr.plot, "Plot parameter was not passed appropriately"
			assert os.path.isdir(corr.data_folder), "DATA_FOLDER doesn't exist"

			# Load the input lightcurve:
			inlc = corr.load_lightcurve(task)

			# Print input lightcurve properties:
			print(inlc.show_properties())
			assert inlc.sector == sector
			assert inlc.camera == camera

			# Run correction:
			tmplc = inlc.copy()
			outlc, status = corr.do_correction(tmplc)

			# Check status
			assert outlc is not None, "Correction fails"
			assert isinstance(outlc, TessLightCurve), "Should return TessLightCurve object"
			assert isinstance(status, corrections.STATUS), "Should return a STATUS object"
			assert status in (corrections.STATUS.OK, corrections.STATUS.WARNING), "STATUS was not set appropriately"

			# Print output lightcurve properties:
			print(outlc.show_properties())

			# Save the lightcurve to FITS file to be tested later on:
			save_file = corr.save_lightcurve(outlc, output_folder=tmpdir)

			# Check contents
			assert len(outlc) == len(inlc), "Input flux is different length to output flux"
			assert isinstance(outlc.flux, np.ndarray), "FLUX is not a ndarray"
			assert isinstance(outlc.flux_err, np.ndarray), "FLUX_ERR is not a ndarray"
			assert isinstance(outlc.quality, np.ndarray), "QUALITY is not a ndarray"
			assert outlc.flux.dtype.type is inlc.flux.dtype.type, "FLUX changes dtype"
			assert outlc.flux_err.dtype.type is inlc.flux_err.dtype.type, "FLUX_ERR changes dtype"
			assert outlc.quality.dtype.type is inlc.quality.dtype.type, "QUALITY changes dtype"
			assert outlc.flux.shape == inlc.flux.shape, "FLUX changes shape"
			assert outlc.flux_err.shape == inlc.flux_err.shape, "FLUX_ERR changes shape"
			assert outlc.quality.shape == inlc.quality.shape, "QUALITY changes shape"

			# Plot output lightcurves:
			fig, (ax1, ax2, ax3) = plt.subplots(3, 1, squeeze=True, figsize=[10, 10])
			ax1.plot(inlc.time, inlc.flux, lw=0.5)
			ax1.set_title(f"{corrector} - Sector {sector:d} - {cadence}s - TIC {starid:d}")
			if compare_lc is not None:
				ax2.plot(compare_lc['time'], compare_lc['flux'], label='Compare', lw=0.5)
				ax3.axhline(0, lw=0.5, ls=':', color='0.7')
				ax3.plot(outlc.time, outlc.flux - compare_lc['flux'], lw=0.5)
			ax2.plot(outlc.time, outlc.flux, label='New', lw=0.5)
			ax1.set_ylabel('Flux [e/s]')
			ax1.minorticks_on()
			ax2.set_ylabel('Relative Flux [ppm]')
			ax2.minorticks_on()
			ax2.legend()
			ax3.set_ylabel('New - Compare [ppm]')
			ax3.set_xlabel('Time [TBJD]')
			ax3.minorticks_on()
			fig.savefig(os.path.join(__dir__, f'test-{corrector}-s{sector:04d}-c{cadence:04d}-tic{starid:011d}.png'), bbox_inches='tight')
			plt.close(fig)

			# Check things that are allowed to change:
			assert all(outlc.flux != inlc.flux), "Input and output flux are identical."
			assert not np.any(np.isinf(outlc.flux)), "FLUX contains Infinite"
			assert not np.any(np.isinf(outlc.flux_err)), "FLUX_ERR contains Infinite"
			assert np.sum(np.isnan(outlc.flux)) < 0.5*len(outlc), "More than half the lightcurve is NaN"
			assert allnan(outlc.flux_err[np.isnan(outlc.flux)]), "FLUX_ERR should be NaN where FLUX is"

			# TODO: Check that quality hasn't changed in ways that are not allowed:
			# - Only values defined in CorrectorQualityFlags
			# - No removal of flags already set
			assert all(outlc.quality >= 0)
			assert all(outlc.quality <= 128)
			assert all(outlc.quality >= inlc.quality)

			# Things that shouldn't change from the corrections:
			assert outlc.targetid == inlc.targetid, "TARGETID has changed"
			assert outlc.label == inlc.label, "LABEL has changed"
			assert outlc.sector == inlc.sector, "SECTOR has changed"
			assert outlc.camera == inlc.camera, "CAMERA has changed"
			assert outlc.ccd == inlc.ccd, "CCD has changed"
			assert outlc.quality_bitmask == inlc.quality_bitmask, "QUALITY_BITMASK has changed"
			assert outlc.ra == inlc.ra, "RA has changed"
			assert outlc.dec == inlc.dec, "DEC has changed"
			assert outlc.mission == 'TESS', "MISSION has changed"
			assert outlc.time_format == 'btjd', "TIME_FORMAT has changed"
			assert outlc.time_scale == 'tdb', "TIME_SCALE has changed"
			assert_array_equal(outlc.time, inlc.time, "TIME has changed")
			assert_array_equal(outlc.timecorr, inlc.timecorr, "TIMECORR has changed")
			assert_array_equal(outlc.cadenceno, inlc.cadenceno, "CADENCENO has changed")
			assert_array_equal(outlc.pixel_quality, inlc.pixel_quality, "PIXEL_QUALITY has changed")
			assert_array_equal(outlc.centroid_col, inlc.centroid_col, "CENTROID_COL has changed")
			assert_array_equal(outlc.centroid_row, inlc.centroid_row, "CENTROID_ROW has changed")

			# Check metadata
			assert tmplc.meta == inlc.meta, "Correction changed METADATA in-place"
			assert outlc.meta['task'] == inlc.meta['task'], "Metadata is incomplete"
			assert isinstance(outlc.meta['additional_headers'], fits.Header)

			# Check performance metrics:
			#logger.warning("VAR: %e", nanvar(outlc.flux))
			if var_goal is not None:
				var_in = nanvar(inlc.flux)
				var_out = nanvar(outlc.flux)
				var_diff = np.abs(var_out - var_goal) / var_goal
				logger.info("VAR: %f - %f - %f", var_in, var_out, var_diff)
				assert_array_less(var_diff, 0.05, "VARIANCE changed outside interval")

			#logger.warning("RMS: %e", rms_timescale(outlc))
			if rms_goal is not None:
				rms_in = rms_timescale(inlc)
				rms_out = rms_timescale(outlc)
				rms_diff = np.abs(rms_out - rms_goal) / rms_goal
				logger.info("RMS: %f - %f - %f", rms_in, rms_out, rms_diff)
				assert_array_less(rms_diff, 0.05, "RMS changed outside interval")

			#logger.warning("PTP: %e", ptp(outlc))
			if ptp_goal is not None:
				ptp_in = ptp(inlc)
				ptp_out = ptp(outlc)
				ptp_diff = np.abs(ptp_out - ptp_goal) / ptp_goal
				logger.info("PTP: %f - %f - %f", ptp_in, ptp_out, ptp_diff)
				assert_array_less(ptp_diff, 0.05, "PTP changed outside interval")

			# Check FITS file:
			with fits.open(os.path.join(tmpdir, save_file), mode='readonly') as hdu:
				# Lightcurve FITS table:
				fitslc = hdu['LIGHTCURVE'].data
				hdr = hdu['LIGHTCURVE'].header

				# Simple checks of header values:
				assert hdu[0].header['TICID'] == starid

				# Checks of things in FITS table that should not have changed at all:
				assert_array_equal(fitslc['TIME'], inlc.time, "FITS: TIME has changed")
				assert_array_equal(fitslc['TIMECORR'], inlc.timecorr, "FITS: TIMECORR has changed")
				assert_array_equal(fitslc['CADENCENO'], inlc.cadenceno, "FITS: CADENCENO has changed")
				assert_array_equal(fitslc['FLUX_RAW'], inlc.flux, "FITS: FLUX_RAW has changed")
				assert_array_equal(fitslc['FLUX_RAW_ERR'], inlc.flux_err, "FITS: FLUX_RAW_ERR has changed")
				assert_array_equal(fitslc['MOM_CENTR1'], inlc.centroid_col, "FITS: CENTROID_COL has changed")
				assert_array_equal(fitslc['MOM_CENTR2'], inlc.centroid_row, "FITS: CENTROID_ROW has changed")

				# Some things are allowed to change, but still within some requirements:
				assert all(fitslc['FLUX_CORR'] != inlc.flux), "FITS: Input and output flux are identical."
				assert np.sum(np.isnan(fitslc['FLUX_CORR'])) < 0.5*len(fitslc['TIME']), "FITS: More than half the lightcurve is NaN"
				assert allnan(fitslc['FLUX_CORR_ERR'][np.isnan(fitslc['FLUX_CORR'])]), "FITS: FLUX_ERR should be NaN where FLUX is"

				if corrector == 'ensemble':
					# Check special headers:
					assert np.isfinite(hdr['ENS_MED']) and hdr['ENS_MED'] > 0
					assert isinstance(hdr['ENS_NUM'], int) and hdr['ENS_NUM'] > 0
					assert hdr['ENS_DLIM'] == 1.0
					assert hdr['ENS_DREL'] == 10.0
					assert hdr['ENS_RLIM'] == 0.4

					# Special extension for ensemble:
					tic = hdu['ENSEMBLE'].data['TIC']
					bzeta = hdu['ENSEMBLE'].data['BZETA']
					assert len(tic) == len(bzeta)
					assert len(np.unique(tic)) == len(tic), "TIC numbers in ENSEMBLE table are not unique"
					assert len(tic) == hdr['ENS_NUM'], "Not the same number of targets in ENSEMBLE table as specified in header"

				elif corrector == 'cbv':
					# Check special headers:
					assert isinstance(hdr['CBV_NUM'], int) and hdr['CBV_NUM'] > 0

					# Check coefficients:
					for k in range(0, hdr['CBV_NUM']+1):
						assert np.isfinite(hdr['CBV_C%d' % k])
					for k in range(1, hdr['CBV_NUM']+1):
						assert np.isfinite(hdr['CBVS_C%d' % k])
					# Check that no other coefficients are present
					assert 'CBV_C%d' % (hdr['CBV_NUM']+1) not in hdr
					assert 'CBVS_C%d' % (hdr['CBV_NUM']+1) not in hdr

				elif corrector == 'kasoc_filter':
					# Check special headers:
					assert hdr['KF_POSS'] == 'None'
					assert np.isfinite(hdr['KF_LONG']) and hdr['KF_LONG'] > 0
					assert np.isfinite(hdr['KF_SHORT']) and hdr['KF_SHORT'] > 0
					assert hdr['KF_SCLIP'] == 4.5
					assert hdr['KF_TCLIP'] == 5.0
					assert hdr['KF_TWDTH'] == 1.0
					assert hdr['KF_PSMTH'] == 200
					assert isinstance(hdr['NUM_PER'], int) and hdr['NUM_PER'] >= 0
					for k in range(1, hdr['NUM_PER']+1):
						assert np.isfinite(hdr['PER_%d' % k]) and hdr['PER_%d' % k] > 0
					# Check that no other periods are present
					assert 'PER_%d' % (hdr['NUM_PER'] + 1) not in hdr

			# Test that the Gzip FITS file has the correct uncompressed file name, by simply
			# decompressing the Gzip file, asking to keep the original file name.
			# This uses the system GZIP utility, since there doesn't seem to be a way to do this
			# through the Python gzip module:
			fpath = os.path.join(tmpdir, save_file)
			fpath_uncompressed = fpath.replace('.fits.gz', '.fits')
			assert not os.path.exists(fpath_uncompressed), "Uncompressed file already exists"
			gzip_output = subprocess.check_output(['gzip', '-dkNv', os.path.basename(fpath)],
				cwd=os.path.dirname(fpath),
				stderr=subprocess.STDOUT,
				encoding='utf8')
			print("Gzip output:")
			print(gzip_output)
			assert os.path.isfile(fpath_uncompressed), "Incorrect uncompressed file name"

			# Just see if we can in fact also open the uncompressed FITS file and get a simple header:
			with fits.open(fpath_uncompressed, mode='readonly') as hdu:
				assert hdu[0].header['TICID'] == starid
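# NOTE: The test above shells out to the system `gzip` utility because, as its comment notes,
# Python's `gzip` module does not expose the stored original file name. For reference, the
# FNAME field can also be read directly from the gzip member header as defined by RFC 1952.
# The helper below is only a sketch of that alternative and is not part of the test suite;
# `read_gzip_original_name` is a hypothetical name.
import struct

def read_gzip_original_name(fpath):
	"""Return the original file name stored in a gzip member header (RFC 1952), or None."""
	FEXTRA, FNAME = 0x04, 0x08
	with open(fpath, 'rb') as fid:
		header = fid.read(10)  # fixed header: ID1 ID2 CM FLG MTIME(4) XFL OS
		if len(header) < 10 or header[:2] != b'\x1f\x8b':
			raise ValueError("Not a gzip file: " + fpath)
		flags = header[3]
		if flags & FEXTRA:
			# Skip the optional "extra" field (2-byte little-endian length + payload):
			xlen, = struct.unpack('<H', fid.read(2))
			fid.read(xlen)
		if not flags & FNAME:
			return None
		# FNAME is a zero-terminated ISO 8859-1 string:
		name = bytearray()
		while True:
			char = fid.read(1)
			if char in (b'', b'\x00'):
				break
			name.extend(char)
	return name.decode('latin-1')

# Usage (sketch, hypothetical file name):
#print(read_gzip_original_name('lightcurve.fits.gz'))  # -> 'lightcurve.fits'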
def main():
	# Parse command line arguments:
	parser = argparse.ArgumentParser(description='Run TESS Corrections in parallel using MPI.')
	#parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true')
	#parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true')
	parser.add_argument('-m', '--method', help='Corrector method to use.', default=None, choices=('ensemble', 'cbv', 'kasoc_filter'))
	parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true')
	parser.add_argument('-p', '--plot', help='Save plots when running.', action='store_true')
	parser.add_argument('--camera', type=int, choices=(1, 2, 3, 4), default=None, help='TESS Camera. Default is to run all cameras.')
	parser.add_argument('--ccd', type=int, choices=(1, 2, 3, 4), default=None, help='TESS CCD. Default is to run all CCDs.')
	parser.add_argument('--datasource', type=str, choices=('ffi', 'tpf'), default='ffi', help='Data source or cadence.')
	parser.add_argument('input_folder', type=str, help='Input directory. This directory should contain a TODO-file and corresponding lightcurves.', nargs='?', default=None)
	args = parser.parse_args()

	# Get input and output folder from environment variables:
	input_folder = args.input_folder
	if input_folder is None:
		input_folder = os.environ.get('TESSCORR_INPUT')
	if not input_folder:
		parser.error("Please specify an INPUT_FOLDER.")
	output_folder = os.environ.get('TESSCORR_OUTPUT', os.path.join(input_folder, 'lightcurves'))

	# Define MPI message tags
	tags = enum.IntEnum('tags', ('READY', 'DONE', 'EXIT', 'START'))

	# Initializations and preliminaries
	comm = MPI.COMM_WORLD   # get MPI communicator object
	size = comm.size        # total number of processes
	rank = comm.rank        # rank of this process
	status = MPI.Status()   # get MPI status object

	if rank == 0:
		try:
			with corrections.TaskManager(input_folder, cleanup=True, overwrite=args.overwrite,
				summary=os.path.join(output_folder, 'summary_corr.json')) as tm:
				# Get list of tasks:
				numtasks = tm.get_number_tasks(camera=args.camera, ccd=args.ccd, datasource=args.datasource)
				tm.logger.info("%d tasks to be run", numtasks)

				# Start the master loop that will assign tasks
				# to the workers:
				num_workers = size - 1
				closed_workers = 0
				tm.logger.info("Master starting with %d workers", num_workers)
				while closed_workers < num_workers:
					# Ask workers for information:
					data = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
					source = status.Get_source()
					tag = status.Get_tag()

					if tag == tags.DONE:
						# The worker is done with a task
						tm.logger.info("Got data from worker %d: %s", source, data)
						tm.save_results(data)

					if tag in (tags.DONE, tags.READY):
						# Worker is ready, so send it a task
						task = tm.get_task(camera=args.camera, ccd=args.ccd, datasource=args.datasource)
						if task:
							task_index = task['priority']
							tm.start_task(task_index)
							comm.send(task, dest=source, tag=tags.START)
							tm.logger.info("Sending task %d to worker %d", task_index, source)
						else:
							comm.send(None, dest=source, tag=tags.EXIT)

					elif tag == tags.EXIT:
						# The worker has exited
						tm.logger.info("Worker %d exited.", source)
						closed_workers += 1

					else:
						# This should never happen, but just to
						# make sure we don't run into an infinite loop:
						raise Exception("Master received an unknown tag: '{0}'".format(tag))

				tm.logger.info("Master finishing")

		except:
			# If something fails in the master
			print(traceback.format_exc().strip())
			comm.Abort(1)

	else:
		# Worker processes execute code below
		# Configure logging within corrections:
		formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
		console = logging.StreamHandler()
		console.setFormatter(formatter)
		logger = logging.getLogger('corrections')
		logger.addHandler(console)
		logger.setLevel(logging.WARNING)

		# Get the class for the selected method:
		CorrClass = corrections.corrclass(args.method)

		try:
			with CorrClass(input_folder, plot=args.plot) as corr:
				# Send signal that we are ready for a task:
				comm.send(None, dest=0, tag=tags.READY)

				while True:
					# Receive a task from the master:
					tic = default_timer()
					task = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
					tag = status.Get_tag()
					toc = default_timer()

					if tag == tags.START:
						result = task.copy()

						# Run the correction:
						try:
							result = corr.correct(task)
						except:
							# Something went wrong
							error_msg = traceback.format_exc().strip()
							result.update({
								'status_corr': corrections.STATUS.ERROR,
								'details': {'errors': error_msg},
							})

						result.update({'worker_wait_time': toc - tic})

						# Send the result back to the master:
						comm.send(result, dest=0, tag=tags.DONE)

						# Attempt some cleanup:
						# TODO: Is this even needed?
						del task, result

					elif tag == tags.EXIT:
						# We were told to EXIT, so let's do that
						break

					else:
						# This should never happen, but just to
						# make sure we don't run into an infinite loop:
						raise Exception("Worker received an unknown tag: '{0}'".format(tag))

		except:
			logger.exception("Something failed in worker")

		finally:
			comm.send(None, dest=0, tag=tags.EXIT)
	if input_folder is None:
		test_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tests', 'input'))
		if args.test:
			input_folder = test_folder
		else:
			input_folder = os.environ.get('TESSCORR_INPUT', test_folder)
	output_folder = os.environ.get('TESSCORR_OUTPUT', os.path.join(input_folder, 'lightcurves'))
	logger.info("Loading input data from '%s'", input_folder)
	logger.info("Putting output data in '%s'", output_folder)

	# Get the class for the selected method:
	CorrClass = corrections.corrclass(args.method)

	# Initialize the corrector class:
	with CorrClass(input_folder, plot=args.plot) as corr:
		# Start the TaskManager:
		with corrections.TaskManager(input_folder) as tm:
			while True:
				if args.all:
					task = tm.get_task()
					if task is None:
						break
				elif args.starid is not None:
					task = tm.get_task(starid=args.starid)
				elif args.random:
					task = tm.get_random_task()
def main():
	# Parse command line arguments:
	parser = argparse.ArgumentParser(description='Run TESS Corrections in parallel using MPI.')
	parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true')
	parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true')
	parser.add_argument('-m', '--method', help='Corrector method to use.', default='cbv', choices=('ensemble', 'cbv', 'kasoc_filter'))
	parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true')
	parser.add_argument('-p', '--plot', help='Save plots when running.', action='store_true')
	group = parser.add_argument_group('Filter which targets to process')
	group.add_argument('--sector', type=int, default=None, help='TESS Sector.')
	group.add_argument('--cadence', type=CadenceType, choices=('ffi', 1800, 600, 120, 20), default=None, help='Cadence. Default is to run all.')
	group.add_argument('--camera', type=int, choices=(1, 2, 3, 4), default=None, help='TESS Camera. Default is to run all cameras.')
	group.add_argument('--ccd', type=int, choices=(1, 2, 3, 4), default=None, help='TESS CCD. Default is to run all CCDs.')
	parser.add_argument('input_folder', type=str, help='Input directory. This directory should contain a TODO-file and corresponding lightcurves.', nargs='?', default=None)
	parser.add_argument('output_folder', type=str, help='Directory to save output in.', nargs='?', default=None)
	args = parser.parse_args()

	# Set logging level:
	logging_level = logging.INFO
	if args.quiet:
		logging_level = logging.WARNING
	elif args.debug:
		logging_level = logging.DEBUG

	# Get input and output folder from environment variables:
	input_folder = args.input_folder
	if input_folder is None:
		input_folder = os.environ.get('TESSCORR_INPUT')
	if not input_folder:
		parser.error("Please specify an INPUT_FOLDER.")
	output_folder = args.output_folder
	if output_folder is None:
		output_folder = os.environ.get('TESSCORR_OUTPUT', os.path.join(os.path.dirname(input_folder), 'lightcurves'))

	# Define MPI message tags
	tags = enum.IntEnum('tags', ('INIT', 'READY', 'DONE', 'EXIT', 'START'))

	# Initializations and preliminaries
	comm = MPI.COMM_WORLD   # get MPI communicator object
	size = comm.size        # total number of processes
	rank = comm.rank        # rank of this process
	status = MPI.Status()   # get MPI status object

	if rank == 0:
		try:
			# Constraints on which targets to process:
			constraints = {
				'sector': args.sector,
				'cadence': args.cadence,
				'camera': args.camera,
				'ccd': args.ccd
			}

			# File path to write summary to:
			summary_file = os.path.join(output_folder, f'summary_corr_{args.method:s}.json')

			# Invoke the TaskManager to ensure that the input TODO-file has the correct
			# columns and indices, which are automatically created by the TaskManager init function:
			with corrections.TaskManager(input_folder, cleanup=True, overwrite=args.overwrite,
				cleanup_constraints=constraints):
				pass

			# Signal that workers are free to initialize:
			comm.Barrier() # Barrier 1

			# Wait for all workers to initialize:
			comm.Barrier() # Barrier 2

			# Start TaskManager, which keeps track of the tasks that need to be performed:
			with corrections.TaskManager(input_folder, overwrite=args.overwrite,
				cleanup_constraints=constraints, summary=summary_file) as tm:
				# Set level of TaskManager logger:
				tm.logger.setLevel(logging_level)

				# Get list of tasks:
				numtasks = tm.get_number_tasks(**constraints)
				tm.logger.info("%d tasks to be run", numtasks)

				# Start the master loop that will assign tasks
				# to the workers:
				num_workers = size - 1
				closed_workers = 0
				tm.logger.info("Master starting with %d workers", num_workers)
				while closed_workers < num_workers:
					# Get information from worker:
					data = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
					source = status.Get_source()
					tag = status.Get_tag()

					if tag == tags.DONE:
						# The worker is done with a task
						tm.logger.debug("Got data from worker %d: %s", source, data)
						tm.save_results(data)

					if tag in (tags.DONE, tags.READY):
						# Worker is ready for a new task, so send it a chunk of tasks:
						tasks = tm.get_task(**constraints, chunk=10)
						if tasks:
							tm.start_task(tasks)
							tm.logger.debug("Sending %d tasks to worker %d", len(tasks), source)
							comm.send(tasks, dest=source, tag=tags.START)
						else:
							comm.send(None, dest=source, tag=tags.EXIT)

					elif tag == tags.EXIT:
						# The worker has exited
						tm.logger.info("Worker %d exited.", source)
						closed_workers += 1

					else:
						# This should never happen, but just to
						# make sure we don't run into an infinite loop:
						raise RuntimeError(f"Master received an unknown tag: '{tag}'")

				tm.logger.info("Master finishing")

		except: # noqa: E722, pragma: no cover
			# If something fails in the master
			print(traceback.format_exc().strip())
			comm.Abort(1)

	else:
		# Worker processes execute code below
		# Configure logging within corrections:
		formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
		console = logging.StreamHandler()
		console.setFormatter(formatter)
		logger = logging.getLogger('corrections')
		logger.addHandler(console)
		logger.setLevel(logging.WARNING)

		# Get the class for the selected method:
		CorrClass = corrections.corrclass(args.method)

		try:
			# Wait for signal that we are okay to initialize:
			comm.Barrier() # Barrier 1

			# We can now safely initialize the corrector on the input file:
			with CorrClass(input_folder, plot=args.plot) as corr:
				# Wait for all workers to be done initializing:
				comm.Barrier() # Barrier 2

				# Send signal that we are ready for a task:
				comm.send(None, dest=0, tag=tags.READY)

				while True:
					# Receive a task from the master:
					tic = default_timer()
					tasks = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
					tag = status.Get_tag()
					toc = default_timer()

					if tag == tags.START:
						# Make sure we can loop through tasks,
						# even in the case we have only gotten one:
						results = []
						if not isinstance(tasks, (list, tuple)):
							tasks = [tasks]

						# Loop through the tasks given to us:
						for task in tasks:
							result = task.copy()

							# Run the correction:
							try:
								result = corr.correct(task)
							except: # noqa: E722
								# Something went wrong
								error_msg = traceback.format_exc().strip()
								result.update({
									'status_corr': corrections.STATUS.ERROR,
									'details': {'errors': [error_msg]},
								})

							result.update({'worker_wait_time': toc - tic})
							results.append(result)

						# Send the results back to the master:
						comm.send(results, dest=0, tag=tags.DONE)

					elif tag == tags.EXIT:
						# We were told to EXIT, so let's do that
						break

					else:
						# This should never happen, but just to
						# make sure we don't run into an infinite loop:
						raise RuntimeError(f"Worker received an unknown tag: '{tag}'")

		except: # noqa: E722, pragma: no cover
			logger.exception("Something failed in worker")

		finally:
			comm.send(None, dest=0, tag=tags.EXIT)
def main():
	# Parse command line arguments:
	parser = argparse.ArgumentParser(description='Run TESS Corrector pipeline on single star.')
	parser.add_argument('-m', '--method', help='Corrector method to use.', default=None, choices=('ensemble', 'cbv', 'kasoc_filter'))
	parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true')
	parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true')
	parser.add_argument('-p', '--plot', help='Save plots when running.', action='store_true')
	parser.add_argument('-r', '--random', help='Run on random target from TODO-list.', action='store_true')
	parser.add_argument('-t', '--test', help='Use test data and ignore TESSCORR_INPUT environment variable.', action='store_true')
	parser.add_argument('-a', '--all', help='Run correction on all targets.', action='store_true')
	parser.add_argument('-o', '--overwrite', help='Overwrite previous runs and start over.', action='store_true')
	group = parser.add_argument_group('Filter which targets to process')
	group.add_argument('--priority', type=int, default=None, help='Priority of target.')
	group.add_argument('--starid', type=int, default=None, help='TIC identifier of target.')
	group.add_argument('--sector', type=int, default=None, help='TESS Sector.')
	group.add_argument('--cadence', type=CadenceType, choices=('ffi', 1800, 600, 120, 20), default=None, help='Cadence. Default is to run all.')
	group.add_argument('--camera', type=int, choices=(1, 2, 3, 4), default=None, help='TESS Camera. Default is to run all cameras.')
	group.add_argument('--ccd', type=int, choices=(1, 2, 3, 4), default=None, help='TESS CCD. Default is to run all CCDs.')
	parser.add_argument('input_folder', type=str, help='Input directory. This directory should contain a TODO-file and corresponding lightcurves.', nargs='?', default=None)
	parser.add_argument('output_folder', type=str, help='Directory to save output in.', nargs='?', default=None)
	args = parser.parse_args()

	# Make sure at least one setting is given:
	if not args.all and args.starid is None and args.priority is None and not args.random:
		parser.error("Please select either a specific STARID, PRIORITY or RANDOM.")

	# Set logging level:
	logging_level = logging.INFO
	if args.quiet:
		logging_level = logging.WARNING
	elif args.debug:
		logging_level = logging.DEBUG

	# Setup logging:
	formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
	console = logging.StreamHandler()
	console.setFormatter(formatter)
	logger = logging.getLogger(__name__)
	if not logger.hasHandlers():
		logger.addHandler(console)
	logger.setLevel(logging_level)
	logger_parent = logging.getLogger('corrections')
	if not logger_parent.hasHandlers():
		logger_parent.addHandler(console)
	logger_parent.setLevel(logging_level)

	# Get input and output folder from environment variables:
	input_folder = args.input_folder
	if input_folder is None:
		test_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tests', 'input'))
		if args.test:
			input_folder = test_folder
		else:
			input_folder = os.environ.get('TESSCORR_INPUT', test_folder)
	output_folder = args.output_folder
	if output_folder is None:
		output_folder = os.environ.get('TESSCORR_OUTPUT', os.path.join(os.path.dirname(input_folder), 'lightcurves'))
	logger.info("Loading input data from '%s'", input_folder)
	logger.info("Putting output data in '%s'", output_folder)

	# Make sure the output directory exists:
	os.makedirs(output_folder, exist_ok=True)

	# Constraints on which targets to process:
	constraints = {
		'priority': args.priority,
		'starid': args.starid,
		'sector': args.sector,
		'cadence': args.cadence,
		'camera': args.camera,
		'ccd': args.ccd
	}

	# Get the class for the selected method:
	CorrClass = corrections.corrclass(args.method)

	# Start the TaskManager:
	with corrections.TaskManager(input_folder, overwrite=args.overwrite, cleanup_constraints=constraints) as tm:
		# Initialize the corrector class:
		with CorrClass(input_folder, plot=args.plot) as corr:
			while True:
				if args.random:
					task = tm.get_random_task()
				else:
					task = tm.get_task(**constraints)
				if task is None:
					break

				# Run the correction:
				result = corr.correct(task, output_folder=output_folder)

				# Construct results to return to TaskManager:
				tm.save_results(result)

				if not args.all:
					break