Example #1
0
def test_ensemble_returned_values():
    """Check that the ensemble corrector returns a sensible lightcurve.

    Runs the ensemble correction on a single task and verifies that the
    corrected lightcurve has the same length and time stamps as the input,
    that the flux was actually modified, and that the returned status is
    ``STATUS.OK``.
    """
    # Both the TaskManager and the corrector are used as context managers
    # so they are properly exited even if something below raises.
    with corrections.TaskManager(INPUT_DIR) as tm:
        task = tm.get_task(starid=starid, camera=camera, ccd=ccd)

    # Initiate the class
    CorrClass = corrections.corrclass('ensemble')
    with CorrClass(INPUT_DIR, plot=False) as corr:
        inlc = corr.load_lightcurve(task)
        # Correct a copy so the untouched input is available for comparison:
        outlc, status = corr.do_correction(inlc.copy())

    # Check input validation
    # NOTE: Disabled check, kept for reference:
    #with pytest.raises(ValueError) as err:
    #	outlc, status = corr.do_correction('hello world')
    #	assert('The input to `do_correction` is not a TessLightCurve object!' in err.value.args[0])

    # Check contents
    assert len(outlc.flux) == len(
        inlc.flux), "Input flux is different length to output flux"
    assert all(
        inlc.time == outlc.time), "Input time is nonidentical to output time"
    assert all(outlc.flux != inlc.flux), "Input and output flux are identical."
    assert len(outlc.flux) == len(
        outlc.time), "Check time and flux have same length"

    # Check status
    assert status == corrections.STATUS.OK, "STATUS was not set appropriately"
Example #2
0
def test_corrclass_type():
    """Check that ``corrclass`` returns the correct corrector class.

    The default and each named method are compared via the ``repr`` of the
    returned class. An unknown method name must raise a ``ValueError`` whose
    message names the offending method.
    """
    CorrClass = corrclass()
    assert repr(CorrClass) == "<class 'corrections.ensemble.EnsembleCorrector'>"

    CorrClass = corrclass('ensemble')
    assert repr(CorrClass) == "<class 'corrections.ensemble.EnsembleCorrector'>"

    CorrClass = corrclass('cbv')
    assert repr(CorrClass) == "<class 'corrections.cbv_corrector.CBVCorrector.CBVCorrector'>"

    CorrClass = corrclass('kasoc_filter')
    assert repr(CorrClass) == "<class 'corrections.KASOCFilterCorrector.KASOCFilterCorrector'>"

    # Pass the variable itself (not the literal string 'method') so that the
    # expected error message below refers to the value actually used.
    method = 'not-a-method'
    with pytest.raises(ValueError) as err:
        corrclass(method)
    # The message must be checked after the ``with`` block: statements placed
    # after the raising call inside the block are never executed.
    assert err.value.args[0] == "Invalid method: '{0}'".format(method)
Example #3
0
def test_corrclass_type():
	"""Verify that ``corrections.corrclass`` maps method names to corrector classes."""

	# Without an argument the default corrector class is returned:
	default_class = corrections.corrclass()
	print(default_class)
	assert default_class is corrections.EnsembleCorrector

	# Each known method name paired with the attribute name of the class
	# expected from the ``corrections`` package:
	expected_by_method = (
		('ensemble', 'EnsembleCorrector'),
		('cbv', 'CBVCorrector'),
		('kasoc_filter', 'KASOCFilterCorrector'),
	)
	for name, class_attr in expected_by_method:
		returned_class = corrections.corrclass(name)
		print(returned_class)
		assert returned_class is getattr(corrections, class_attr)

	# Unknown method names must be rejected with a descriptive error:
	method = 'not-a-method'
	with pytest.raises(ValueError) as err:
		corrections.corrclass(method)
	assert err.value.args[0] == "Invalid method: '{0}'".format(method)
Example #4
0
def test_run_metadata():
    """Check that the ensemble correction preserves the task metadata.

    Runs the ensemble correction on a single task and verifies that the
    ``task`` entry in the metadata of the corrected lightcurve is identical
    to the one in the input lightcurve.
    """
    # Both the TaskManager and the corrector are used as context managers
    # so they are properly exited even if something below raises.
    with corrections.TaskManager(INPUT_DIR) as tm:
        task = tm.get_task(starid=starid, camera=camera, ccd=ccd)

    # Initiate the class
    CorrClass = corrections.corrclass('ensemble')
    with CorrClass(INPUT_DIR, plot=False) as corr:
        inlc = corr.load_lightcurve(task)
        # Correct a copy so the untouched input is available for comparison:
        outlc, status = corr.do_correction(inlc.copy())

    # Check metadata
    # NOTE: Disabled checks, kept for reference:
    #assert 'fmean' in outlc.meta, "Metadata is incomplete"
    #assert 'fstd' in outlc.meta, "Metadata is incomplete"
    #assert 'frange' in outlc.meta, "Metadata is incomplete"
    #assert 'drange' in outlc.meta, "Metadata is incomplete"
    assert outlc.meta['task']['starid'] == inlc.meta['task'][
        'starid'], "Metadata is incomplete"
    assert outlc.meta['task'] == inlc.meta['task'], "Metadata is incomplete"
Example #5
0
def test_known_star(SHARED_INPUT_DIR, corrector, starid, cadence, var_goal, rms_goal, ptp_goal):
	"""
	Run a corrector on a known star and check the result against expectations.

	The corrected lightcurve is checked for structure (types, dtypes and shapes),
	for properties that must not change during correction, and - for each goal
	that is not None - that the variance, RMS and point-to-point scatter are within
	5% of the given goal. The lightcurve is also saved to a FITS file in a temporary
	directory, which is re-opened and checked, including corrector-specific headers.
	As a side effect, a diagnostic PNG is saved next to this test file.

	Parameters:
		SHARED_INPUT_DIR: Input directory passed to the TaskManager and the corrector
			(presumably provided by a pytest fixture - not visible here).
		corrector: Name of the corrector, passed to ``corrections.corrclass``.
		starid: TIC identifier of the star to test.
		cadence: Cadence used to look up the task.
		var_goal: Expected variance of the corrected flux, or None to skip the check.
		rms_goal: Expected value of ``rms_timescale``, or None to skip the check.
		ptp_goal: Expected value of ``ptp``, or None to skip the check.
	"""

	# All stars we check here come from the same sector and camera.
	# Define these here for the future where we may test on other combinations of these:
	sector = 1
	camera = 1

	__dir__ = os.path.abspath(os.path.dirname(__file__))
	logger = logging.getLogger(__name__)
	logger.info("-------------------------------------------------------------")
	logger.info("CORRECTOR = %s, SECTOR=%d, CADENCE=%s, STARID=%d", corrector, sector, cadence, starid)

	# All stars are from the same CCD, find the task for it:
	with corrections.TaskManager(SHARED_INPUT_DIR) as tm:
		task = tm.get_task(starid=starid, sector=sector, camera=camera, cadence=cadence)

	# Check that task was actually found:
	assert task is not None, "Task could not be found"

	# Load lightcurve that will also be plotted together with the result:
	# This lightcurve is of the same object, at a state where it was deemed that the
	# corrections were doing a good job.
	# A missing comparison file only triggers a warning; the test continues without it.
	compare_lc_path = os.path.join(__dir__, 'compare', f'compare-{corrector}-s{sector:04d}-c{cadence:04d}-tic{starid:011d}.ecsv.gz')
	compare_lc = None
	if os.path.isfile(compare_lc_path):
		compare_lc = Table.read(compare_lc_path, format='ascii.ecsv')
	else:
		warnings.warn("Comparison data does not exist: " + compare_lc_path)

	# Initiate the class
	CorrClass = corrections.corrclass(corrector)
	with tempfile.TemporaryDirectory() as tmpdir:
		with CorrClass(SHARED_INPUT_DIR, plot=True) as corr:
			# Check basic parameters of object (from BaseCorrector):
			assert corr.input_folder == SHARED_INPUT_DIR, "Incorrect input folder"
			assert corr.plot, "Plot parameter passed appropriately"
			assert os.path.isdir(corr.data_folder), "DATA_FOLDER doesn't exist"

			# Load the input lightcurve:
			inlc = corr.load_lightcurve(task)

			# Print input lightcurve properties:
			print( inlc.show_properties() )
			assert inlc.sector == sector
			assert inlc.camera == camera

			# Run correction on a copy, so it can later be checked that the
			# correction did not modify the metadata of its input in-place:
			tmplc = inlc.copy()
			outlc, status = corr.do_correction(tmplc)

			# Check status
			assert outlc is not None, "Correction fails"
			assert isinstance(outlc, TessLightCurve), "Should return TessLightCurve object"
			assert isinstance(status, corrections.STATUS), "Should return a STATUS object"
			assert status in (corrections.STATUS.OK, corrections.STATUS.WARNING), "STATUS was not set appropriately"

			# Print output lightcurve properties:
			print( outlc.show_properties() )

			# Save the lightcurve to FITS file to be tested later on:
			save_file = corr.save_lightcurve(outlc, output_folder=tmpdir)

		# Check contents
		assert len(outlc) == len(inlc), "Input flux ix different length to output flux"
		assert isinstance(outlc.flux, np.ndarray), "FLUX is not a ndarray"
		assert isinstance(outlc.flux_err, np.ndarray), "FLUX_ERR is not a ndarray"
		assert isinstance(outlc.quality, np.ndarray), "QUALITY is not a ndarray"
		assert outlc.flux.dtype.type is inlc.flux.dtype.type, "FLUX changes dtype"
		assert outlc.flux_err.dtype.type is inlc.flux_err.dtype.type, "FLUX_ERR changes dtype"
		assert outlc.quality.dtype.type is inlc.quality.dtype.type, "QUALITY changes dtype"
		assert outlc.flux.shape == inlc.flux.shape, "FLUX changes shape"
		assert outlc.flux_err.shape == inlc.flux_err.shape, "FLUX_ERR changes shape"
		assert outlc.quality.shape == inlc.quality.shape, "QUALITY changes shape"

		# Plot output lightcurves:
		# Top panel: input flux. Middle panel: corrected flux (and comparison, if available).
		# Bottom panel: difference between corrected and comparison flux (if available).
		fig, (ax1, ax2, ax3) = plt.subplots(3, 1, squeeze=True, figsize=[10, 10])
		ax1.plot(inlc.time, inlc.flux, lw=0.5)
		ax1.set_title(f"{corrector} - Sector {sector:d} - {cadence}s - TIC {starid:d}")
		if compare_lc:
			ax2.plot(compare_lc['time'], compare_lc['flux'], label='Compare', lw=0.5)
			ax3.axhline(0, lw=0.5, ls=':', color='0.7')
			ax3.plot(outlc.time, outlc.flux - compare_lc['flux'], lw=0.5)
		ax2.plot(outlc.time, outlc.flux, label='New', lw=0.5)
		ax1.set_ylabel('Flux [e/s]')
		ax1.minorticks_on()
		ax2.set_ylabel('Relative Flux [ppm]')
		ax2.minorticks_on()
		ax2.legend()
		ax3.set_ylabel('New - Compare [ppm]')
		ax3.set_xlabel('Time [TBJD]')
		ax3.minorticks_on()
		fig.savefig(os.path.join(__dir__, f'test-{corrector}-s{sector:04d}-c{cadence:04d}-tic{starid:011d}.png'), bbox_inches='tight')
		plt.close(fig)

		# Check things that are allowed to change:
		assert all(outlc.flux != inlc.flux), "Input and output flux are identical."
		assert not np.any(np.isinf(outlc.flux)), "FLUX contains Infinite"
		assert not np.any(np.isinf(outlc.flux_err)), "FLUX_ERR contains Infinite"
		assert np.sum(np.isnan(outlc.flux)) < 0.5*len(outlc), "More than half the lightcurve is NaN"
		assert allnan(outlc.flux_err[np.isnan(outlc.flux)]), "FLUX_ERR should be NaN where FLUX is"

		# TODO: Check that quality hasn't changed in ways that are not allowed:
		# - Only values defined in CorrectorQualityFlags
		# - No removal of flags already set
		assert all(outlc.quality >= 0)
		assert all(outlc.quality <= 128)
		assert all(outlc.quality >= inlc.quality)

		# Things that shouldn't change from the corrections:
		assert outlc.targetid == inlc.targetid, "TARGETID has changed"
		assert outlc.label == inlc.label, "LABEL has changed"
		assert outlc.sector == inlc.sector, "SECTOR has changed"
		assert outlc.camera == inlc.camera, "CAMERA has changed"
		assert outlc.ccd == inlc.ccd, "CCD has changed"
		assert outlc.quality_bitmask == inlc.quality_bitmask, "QUALITY_BITMASK has changed"
		assert outlc.ra == inlc.ra, "RA has changed"
		assert outlc.dec == inlc.dec, "DEC has changed"
		assert outlc.mission == 'TESS', "MISSION has changed"
		assert outlc.time_format == 'btjd', "TIME_FORMAT has changed"
		assert outlc.time_scale == 'tdb', "TIME_SCALE has changed"
		assert_array_equal(outlc.time, inlc.time, "TIME has changed")
		assert_array_equal(outlc.timecorr, inlc.timecorr, "TIMECORR has changed")
		assert_array_equal(outlc.cadenceno, inlc.cadenceno, "CADENCENO has changed")
		assert_array_equal(outlc.pixel_quality, inlc.pixel_quality, "PIXEL_QUALITY has changed")
		assert_array_equal(outlc.centroid_col, inlc.centroid_col, "CENTROID_COL has changed")
		assert_array_equal(outlc.centroid_row, inlc.centroid_row, "CENTROID_ROW has changed")

		# Check metadata
		assert tmplc.meta == inlc.meta, "Correction changed METADATA in-place"
		assert outlc.meta['task'] == inlc.meta['task'], "Metadata is incomplete"
		assert isinstance(outlc.meta['additional_headers'], fits.Header)

		# Check performance metrics:
		# Each metric is compared to its goal as a relative difference, which must be below 5%.
		#logger.warning("VAR: %e", nanvar(outlc.flux))
		if var_goal is not None:
			var_in = nanvar(inlc.flux)
			var_out = nanvar(outlc.flux)
			var_diff = np.abs(var_out - var_goal) / var_goal
			logger.info("VAR: %f - %f - %f", var_in, var_out, var_diff)
			assert_array_less(var_diff, 0.05, "VARIANCE changed outside interval")

		#logger.warning("RMS: %e", rms_timescale(outlc))
		if rms_goal is not None:
			rms_in = rms_timescale(inlc)
			rms_out = rms_timescale(outlc)
			rms_diff = np.abs(rms_out - rms_goal) / rms_goal
			logger.info("RMS: %f - %f - %f", rms_in, rms_out, rms_diff)
			assert_array_less(rms_diff, 0.05, "RMS changed outside interval")

		#logger.warning("PTP: %e", ptp(outlc))
		if ptp_goal is not None:
			ptp_in = ptp(inlc)
			ptp_out = ptp(outlc)
			ptp_diff = np.abs(ptp_out - ptp_goal) / ptp_goal
			logger.info("PTP: %f - %f - %f", ptp_in, ptp_out, ptp_diff)
			assert_array_less(ptp_diff, 0.05, "PTP changed outside interval")

		# Check FITS file:
		with fits.open(os.path.join(tmpdir, save_file), mode='readonly') as hdu:
			# Lightcurve FITS table:
			fitslc = hdu['LIGHTCURVE'].data
			hdr = hdu['LIGHTCURVE'].header

			# Simple checks of header values:
			assert hdu[0].header['TICID'] == starid

			# Checks of things in FITS table that should not have changed at all:
			assert_array_equal(fitslc['TIME'], inlc.time, "FITS: TIME has changed")
			assert_array_equal(fitslc['TIMECORR'], inlc.timecorr, "FITS: TIMECORR has changed")
			assert_array_equal(fitslc['CADENCENO'], inlc.cadenceno, "FITS: CADENCENO has changed")
			assert_array_equal(fitslc['FLUX_RAW'], inlc.flux, "FITS: FLUX_RAW has changed")
			assert_array_equal(fitslc['FLUX_RAW_ERR'], inlc.flux_err, "FITS: FLUX_RAW_ERR has changed")
			assert_array_equal(fitslc['MOM_CENTR1'], inlc.centroid_col, "FITS: CENTROID_COL has changed")
			assert_array_equal(fitslc['MOM_CENTR2'], inlc.centroid_row, "FITS: CENTROID_ROW has changed")

			# Some things are allowed to change, but still within some requirements:
			assert all(fitslc['FLUX_CORR'] != inlc.flux), "FITS: Input and output flux are identical."
			assert np.sum(np.isnan(fitslc['FLUX_CORR'])) < 0.5*len(fitslc['TIME']), "FITS: More than half the lightcurve is NaN"
			assert allnan(fitslc['FLUX_CORR_ERR'][np.isnan(fitslc['FLUX_CORR'])]), "FITS: FLUX_ERR should be NaN where FLUX is"

			if corrector == 'ensemble':
				# Check special headers:
				assert np.isfinite(hdr['ENS_MED']) and hdr['ENS_MED'] > 0
				assert isinstance(hdr['ENS_NUM'], int) and hdr['ENS_NUM'] > 0
				assert hdr['ENS_DLIM'] == 1.0
				assert hdr['ENS_DREL'] == 10.0
				assert hdr['ENS_RLIM'] == 0.4

				# Special extension for ensemble:
				tic = hdu['ENSEMBLE'].data['TIC']
				bzeta = hdu['ENSEMBLE'].data['BZETA']
				assert len(tic) == len(bzeta)
				assert len(np.unique(tic)) == len(tic), "TIC numbers in ENSEMBLE table are not unique"
				assert len(tic) == hdr['ENS_NUM'], "Not the same number of targets in ENSEMBLE table as specified in header"

			elif corrector == 'cbv':
				# Check special headers:
				assert isinstance(hdr['CBV_NUM'], int) and hdr['CBV_NUM'] > 0

				# Check coefficients:
				# Note that the CBV_C keywords are checked from index 0, the CBVS_C keywords from index 1.
				for k in range(0, hdr['CBV_NUM']+1):
					assert np.isfinite(hdr['CBV_C%d' % k])
				for k in range(1, hdr['CBV_NUM']+1):
					assert np.isfinite(hdr['CBVS_C%d' % k])
				# Check that no other coefficients are present
				assert 'CBV_C%d' % (hdr['CBV_NUM']+1) not in hdr
				assert 'CBVS_C%d' % (hdr['CBV_NUM']+1) not in hdr

			elif corrector == 'kasoc_filter':
				# Check special headers:
				assert hdr['KF_POSS'] == 'None'
				assert np.isfinite(hdr['KF_LONG']) and hdr['KF_LONG'] > 0
				assert np.isfinite(hdr['KF_SHORT']) and hdr['KF_SHORT'] > 0
				assert hdr['KF_SCLIP'] == 4.5
				assert hdr['KF_TCLIP'] == 5.0
				assert hdr['KF_TWDTH'] == 1.0
				assert hdr['KF_PSMTH'] == 200

				assert isinstance(hdr['NUM_PER'], int) and hdr['NUM_PER'] >= 0
				for k in range(1, hdr['NUM_PER']+1):
					assert np.isfinite(hdr['PER_%d' % k]) and hdr['PER_%d' % k] > 0
				# Check that no other periods are present
				assert 'PER_%d' % (hdr['NUM_PER'] + 1) not in hdr

		# Test that the Gzip FITS file has the correct uncompressed file name, by simply
		# decompressing the Gzip file, asking to keep the original file name.
		# This uses the system GZIP utility, since there doesn't seem to be a way to do this
		# through the Python gzip module:
		fpath = os.path.join(tmpdir, save_file)
		fpath_uncompressed = fpath.replace('.fits.gz', '.fits')
		assert not os.path.exists(fpath_uncompressed), "Uncompressed file already exists"
		gzip_output = subprocess.check_output(['gzip', '-dkNv', os.path.basename(fpath)],
			cwd=os.path.dirname(fpath),
			stderr=subprocess.STDOUT,
			encoding='utf8')
		print("Gzip output:")
		print(gzip_output)
		assert os.path.isfile(fpath_uncompressed), "Incorrect uncompressed file name"

		# Just see if we can in fact also open the uncompressed FITS file and get a simple header:
		with fits.open(fpath_uncompressed, mode='readonly') as hdu:
			assert hdu[0].header['TICID'] == starid
Example #6
0
def main():
    """Run TESS corrections in parallel using MPI.

    The process with rank 0 acts as master: it opens the TaskManager on the
    input folder and hands out one task at a time to the workers, saving the
    results they send back. All other ranks act as workers: they initialize
    the selected corrector, request tasks from the master, run the correction
    and send back the result, until they are told to exit.

    The input folder is taken from the command line, falling back to the
    ``TESSCORR_INPUT`` environment variable. The output folder is taken from
    the ``TESSCORR_OUTPUT`` environment variable, falling back to a
    ``lightcurves`` directory inside the input folder.
    """
    # Parse command line arguments:
    parser = argparse.ArgumentParser(
        description='Run TESS Corrections in parallel using MPI.')
    #parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true')
    #parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true')
    parser.add_argument('-m',
                        '--method',
                        help='Corrector method to use.',
                        default=None,
                        choices=('ensemble', 'cbv', 'kasoc_filter'))
    parser.add_argument('-o',
                        '--overwrite',
                        help='Overwrite existing results.',
                        action='store_true')
    parser.add_argument('-p',
                        '--plot',
                        help='Save plots when running.',
                        action='store_true')
    parser.add_argument('--camera',
                        type=int,
                        choices=(1, 2, 3, 4),
                        default=None,
                        help='TESS Camera. Default is to run all cameras.')
    parser.add_argument('--ccd',
                        type=int,
                        choices=(1, 2, 3, 4),
                        default=None,
                        help='TESS CCD. Default is to run all CCDs.')
    parser.add_argument('--datasource',
                        type=str,
                        choices=('ffi', 'tpf'),
                        default='ffi',
                        help='Data source or cadence.')
    parser.add_argument(
        'input_folder',
        type=str,
        help=
        'Input directory. This directory should contain a TODO-file and corresponding lightcurves.',
        nargs='?',
        default=None)
    args = parser.parse_args()

    # Get input and output folder from environment variables:
    # The command line argument takes precedence over TESSCORR_INPUT.
    input_folder = args.input_folder
    if input_folder is None:
        input_folder = os.environ.get('TESSCORR_INPUT')
    if not input_folder:
        parser.error("Please specify an INPUT_FOLDER.")
    output_folder = os.environ.get('TESSCORR_OUTPUT',
                                   os.path.join(input_folder, 'lightcurves'))

    # Define MPI message tags
    tags = enum.IntEnum('tags', ('READY', 'DONE', 'EXIT', 'START'))

    # Initializations and preliminaries
    comm = MPI.COMM_WORLD  # get MPI communicator object
    size = comm.size  # total number of processes
    rank = comm.rank  # rank of this process
    status = MPI.Status()  # get MPI status object

    if rank == 0:
        # Master process: hands out tasks and collects results.
        try:
            with corrections.TaskManager(input_folder,
                                         cleanup=True,
                                         overwrite=args.overwrite,
                                         summary=os.path.join(
                                             output_folder,
                                             'summary_corr.json')) as tm:
                # Get list of tasks:
                numtasks = tm.get_number_tasks(camera=args.camera,
                                               ccd=args.ccd,
                                               datasource=args.datasource)
                tm.logger.info("%d tasks to be run", numtasks)

                # Start the master loop that will assign tasks
                # to the workers:
                # The loop ends once every worker has reported EXIT.
                num_workers = size - 1
                closed_workers = 0
                tm.logger.info("Master starting with %d workers", num_workers)
                while closed_workers < num_workers:
                    # Ask workers for information:
                    data = comm.recv(source=MPI.ANY_SOURCE,
                                     tag=MPI.ANY_TAG,
                                     status=status)
                    source = status.Get_source()
                    tag = status.Get_tag()

                    if tag == tags.DONE:
                        # The worker is done with a task
                        tm.logger.info("Got data from worker %d: %s", source,
                                       data)
                        tm.save_results(data)

                    # Deliberately a separate "if": a worker that is DONE
                    # is also ready to receive its next task.
                    if tag in (tags.DONE, tags.READY):
                        # Worker is ready, so send it a task
                        task = tm.get_task(camera=args.camera,
                                           ccd=args.ccd,
                                           datasource=args.datasource)
                        if task:
                            task_index = task['priority']
                            tm.start_task(task_index)
                            comm.send(task, dest=source, tag=tags.START)
                            tm.logger.info("Sending task %d to worker %d",
                                           task_index, source)
                        else:
                            # No more tasks available, so tell the worker to exit:
                            comm.send(None, dest=source, tag=tags.EXIT)

                    elif tag == tags.EXIT:
                        # The worker has exited
                        tm.logger.info("Worker %d exited.", source)
                        closed_workers += 1

                    else:
                        # This should never happen, but just to
                        # make sure we don't run into an infinite loop:
                        raise Exception(
                            "Master received an unknown tag: '{0}'".format(
                                tag))

                tm.logger.info("Master finishing")

        except:
            # If something fails in the master
            # Print the traceback and abort the MPI communicator with exit code 1.
            print(traceback.format_exc().strip())
            comm.Abort(1)

    else:
        # Worker processes execute code below
        # Configure logging for the 'corrections' logger:
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger = logging.getLogger('corrections')
        logger.addHandler(console)
        logger.setLevel(logging.WARNING)

        # Get the class for the selected method:
        CorrClass = corrections.corrclass(args.method)

        try:
            with CorrClass(input_folder, plot=args.plot) as corr:

                # Send signal that we are ready for task:
                comm.send(None, dest=0, tag=tags.READY)

                while True:
                    # Receive a task from the master:
                    # The time spent waiting for the master is measured
                    # and reported back together with the result.
                    tic = default_timer()
                    task = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
                    tag = status.Get_tag()
                    toc = default_timer()

                    if tag == tags.START:
                        # Start from a copy of the task, which is used as
                        # the result in case the correction raises:
                        result = task.copy()

                        # Run the correction:
                        try:
                            result = corr.correct(task)
                        except:
                            # Something went wrong
                            error_msg = traceback.format_exc().strip()
                            result.update({
                                'status_corr': corrections.STATUS.ERROR,
                                'details': {
                                    'errors': error_msg
                                },
                            })

                        result.update({'worker_wait_time': toc - tic})
                        # Send the result back to the master:
                        comm.send(result, dest=0, tag=tags.DONE)

                        # Attempt some cleanup:
                        # TODO: Is this even needed?
                        del task, result

                    elif tag == tags.EXIT:
                        # We were told to EXIT, so lets do that
                        break

                    else:
                        # This should never happen, but just to
                        # make sure we don't run into an infinite loop:
                        raise Exception(
                            "Worker received an unknown tag: '{0}'".format(
                                tag))

        except:
            logger.exception("Something failed in worker")

        finally:
            # Always tell the master that this worker has exited, since the
            # master loop only ends when all workers have reported EXIT:
            comm.send(None, dest=0, tag=tags.EXIT)
Example #7
0
    if input_folder is None:
        test_folder = os.path.abspath(
            os.path.join(os.path.dirname(__file__), 'tests', 'input'))
        if args.test:
            input_folder = test_folder
        else:
            input_folder = os.environ.get('TESSCORR_INPUT', test_folder)

    output_folder = os.environ.get('TESSCORR_OUTPUT',
                                   os.path.join(input_folder, 'lightcurves'))

    logger.info("Loading input data from '%s'", input_folder)
    logger.info("Putting output data in '%s'", output_folder)

    # Get the class for the selected method:
    CorrClass = corrections.corrclass(args.method)

    # Initialize the corrector class:
    with CorrClass(input_folder, plot=args.plot) as corr:

        # Start the TaskManager:
        with corrections.TaskManager(input_folder) as tm:
            while True:
                if args.all:
                    task = tm.get_task()
                    if task is None: break
                elif args.starid is not None:
                    task = tm.get_task(starid=args.starid)
                elif args.random:
                    task = tm.get_random_task()
Example #8
0
def main():
    """Run TESS corrections in parallel using MPI.

    The process with rank 0 acts as master: it first runs the TaskManager
    with cleanup enabled, then hands out chunks of tasks to the workers and
    saves the results they send back. All other ranks act as workers: they
    initialize the selected corrector, request tasks from the master, run the
    correction on each task and send back the results, until told to exit.

    Two MPI barriers synchronize start-up: workers wait at the first barrier
    until the master has finished the cleanup, and all ranks wait at the
    second barrier until every worker has initialized its corrector.

    The input folder is taken from the command line, falling back to the
    ``TESSCORR_INPUT`` environment variable. The output folder is taken from
    the command line, then the ``TESSCORR_OUTPUT`` environment variable, and
    finally a ``lightcurves`` directory next to the input folder.
    """
    # Parse command line arguments:
    parser = argparse.ArgumentParser(
        description='Run TESS Corrections in parallel using MPI.')
    parser.add_argument('-d',
                        '--debug',
                        help='Print debug messages.',
                        action='store_true')
    parser.add_argument('-q',
                        '--quiet',
                        help='Only report warnings and errors.',
                        action='store_true')
    parser.add_argument('-m',
                        '--method',
                        help='Corrector method to use.',
                        default='cbv',
                        choices=('ensemble', 'cbv', 'kasoc_filter'))
    parser.add_argument('-o',
                        '--overwrite',
                        help='Overwrite existing results.',
                        action='store_true')
    parser.add_argument('-p',
                        '--plot',
                        help='Save plots when running.',
                        action='store_true')
    group = parser.add_argument_group('Filter which targets to process')
    group.add_argument('--sector', type=int, default=None, help='TESS Sector.')
    group.add_argument('--cadence',
                       type=CadenceType,
                       choices=('ffi', 1800, 600, 120, 20),
                       default=None,
                       help='Cadence. Default is to run all.')
    group.add_argument('--camera',
                       type=int,
                       choices=(1, 2, 3, 4),
                       default=None,
                       help='TESS Camera. Default is to run all cameras.')
    group.add_argument('--ccd',
                       type=int,
                       choices=(1, 2, 3, 4),
                       default=None,
                       help='TESS CCD. Default is to run all CCDs.')
    parser.add_argument(
        'input_folder',
        type=str,
        help=
        'Input directory. This directory should contain a TODO-file and corresponding lightcurves.',
        nargs='?',
        default=None)
    parser.add_argument('output_folder',
                        type=str,
                        help='Directory to save output in.',
                        nargs='?',
                        default=None)
    args = parser.parse_args()

    # Set logging level:
    # Note that --quiet takes precedence over --debug.
    logging_level = logging.INFO
    if args.quiet:
        logging_level = logging.WARNING
    elif args.debug:
        logging_level = logging.DEBUG

    # Get input and output folder from environment variables:
    # Command line arguments take precedence over the environment.
    input_folder = args.input_folder
    if input_folder is None:
        input_folder = os.environ.get('TESSCORR_INPUT')
    if not input_folder:
        parser.error("Please specify an INPUT_FOLDER.")

    output_folder = args.output_folder
    if output_folder is None:
        output_folder = os.environ.get(
            'TESSCORR_OUTPUT',
            os.path.join(os.path.dirname(input_folder), 'lightcurves'))

    # Define MPI message tags
    tags = enum.IntEnum('tags', ('INIT', 'READY', 'DONE', 'EXIT', 'START'))

    # Initializations and preliminaries
    comm = MPI.COMM_WORLD  # get MPI communicator object
    size = comm.size  # total number of processes
    rank = comm.rank  # rank of this process
    status = MPI.Status()  # get MPI status object

    if rank == 0:
        # Master process: hands out tasks and collects results.
        try:
            # Constraints on which targets to process:
            constraints = {
                'sector': args.sector,
                'cadence': args.cadence,
                'camera': args.camera,
                'ccd': args.ccd
            }

            # File path to write summary to:
            summary_file = os.path.join(output_folder,
                                        f'summary_corr_{args.method:s}.json')

            # Invoke the TaskManager to ensure that the input TODO-file has the correct columns
            # and indices, which is automatically created by the TaskManager init function.
            with corrections.TaskManager(input_folder,
                                         cleanup=True,
                                         overwrite=args.overwrite,
                                         cleanup_constraints=constraints):
                pass

            # Signal that workers are free to initialize:
            comm.Barrier()  # Barrier 1

            # Wait for all workers to initialize:
            comm.Barrier()  # Barrier 2

            # Start TaskManager, which keeps track of the task that needs to be performed:
            with corrections.TaskManager(input_folder,
                                         overwrite=args.overwrite,
                                         cleanup_constraints=constraints,
                                         summary=summary_file) as tm:

                # Set level of TaskManager logger:
                tm.logger.setLevel(logging_level)

                # Get list of tasks:
                numtasks = tm.get_number_tasks(**constraints)
                tm.logger.info("%d tasks to be run", numtasks)

                # Start the master loop that will assign tasks
                # to the workers:
                # The loop ends once every worker has reported EXIT.
                num_workers = size - 1
                closed_workers = 0
                tm.logger.info("Master starting with %d workers", num_workers)
                while closed_workers < num_workers:
                    # Get information from worker:
                    data = comm.recv(source=MPI.ANY_SOURCE,
                                     tag=MPI.ANY_TAG,
                                     status=status)
                    source = status.Get_source()
                    tag = status.Get_tag()

                    if tag == tags.DONE:
                        # The worker is done with a task
                        tm.logger.debug("Got data from worker %d: %s", source,
                                        data)
                        tm.save_results(data)

                    # Deliberately a separate "if": a worker that is DONE
                    # is also ready to receive its next tasks.
                    if tag in (tags.DONE, tags.READY):
                        # Worker is ready for a new task, so send it a task
                        tasks = tm.get_task(**constraints, chunk=10)
                        if tasks:
                            tm.start_task(tasks)
                            tm.logger.debug("Sending %d tasks to worker %d",
                                            len(tasks), source)
                            comm.send(tasks, dest=source, tag=tags.START)
                        else:
                            # No more tasks available, so tell the worker to exit:
                            comm.send(None, dest=source, tag=tags.EXIT)

                    elif tag == tags.EXIT:
                        # The worker has exited
                        tm.logger.info("Worker %d exited.", source)
                        closed_workers += 1

                    else:
                        # This should never happen, but just to
                        # make sure we don't run into an infinite loop:
                        raise RuntimeError(
                            f"Master received an unknown tag: '{tag}'")

                tm.logger.info("Master finishing")

        except:  # noqa: E722, pragma: no cover
            # If something fails in the master
            # Print the traceback and abort the MPI communicator with exit code 1.
            print(traceback.format_exc().strip())
            comm.Abort(1)

    else:
        # Worker processes execute code below
        # Configure logging for the 'corrections' logger:
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger = logging.getLogger('corrections')
        logger.addHandler(console)
        logger.setLevel(logging.WARNING)

        # Get the class for the selected method:
        CorrClass = corrections.corrclass(args.method)

        try:
            # Wait for signal that we are okay to initialize:
            comm.Barrier()  # Barrier 1

            # We can now safely initialize the corrector on the input file:
            with CorrClass(input_folder, plot=args.plot) as corr:

                # Wait for all workers to be done initializing:
                comm.Barrier()  # Barrier 2

                # Send signal that we are ready for task:
                comm.send(None, dest=0, tag=tags.READY)

                while True:
                    # Receive a task from the master:
                    # The time spent waiting for the master is measured
                    # and reported back together with each result.
                    tic = default_timer()
                    tasks = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
                    tag = status.Get_tag()
                    toc = default_timer()

                    if tag == tags.START:
                        # Make sure we can loop through tasks,
                        # even in the case we have only gotten one:
                        results = []
                        if isinstance(tasks, dict):
                            # A single task must be wrapped in a list.
                            # Calling list() on a dict would instead yield
                            # its keys, not a one-element list of tasks.
                            tasks = [tasks]
                        elif not isinstance(tasks, (list, tuple)):
                            tasks = list(tasks)

                        # Loop through the tasks given to us:
                        for task in tasks:
                            # Start from a copy of the task, which is used as
                            # the result in case the correction raises:
                            result = task.copy()

                            # Run the correction:
                            try:
                                result = corr.correct(task)
                            except:  # noqa: E722
                                # Something went wrong
                                error_msg = traceback.format_exc().strip()
                                result.update({
                                    'status_corr':
                                    corrections.STATUS.ERROR,
                                    'details': {
                                        'errors': [error_msg]
                                    },
                                })

                            result.update({'worker_wait_time': toc - tic})
                            results.append(result)

                        # Send the result back to the master:
                        comm.send(results, dest=0, tag=tags.DONE)

                    elif tag == tags.EXIT:
                        # We were told to EXIT, so lets do that
                        break

                    else:
                        # This should never happen, but just to
                        # make sure we don't run into an infinite loop:
                        raise RuntimeError(
                            f"Worker received an unknown tag: '{tag}'")

        except:  # noqa: E722, pragma: no cover
            logger.exception("Something failed in worker")

        finally:
            # Always tell the master that this worker has exited, since the
            # master loop only ends when all workers have reported EXIT:
            comm.send(None, dest=0, tag=tags.EXIT)
Example #9
0
def main():
    """
    Command-line entry point for running the TESS corrector.

    Parses the command-line options, configures logging, resolves the input
    and output directories (from arguments, environment variables or the
    bundled test data), and then fetches tasks from the TaskManager, runs the
    selected corrector on each of them and saves the results back.
    """
    # ---- Command-line interface -------------------------------------------
    cli = argparse.ArgumentParser(
        description='Run TESS Corrector pipeline on single star.')
    cli.add_argument('-m', '--method',
                     default=None,
                     choices=('ensemble', 'cbv', 'kasoc_filter'),
                     help='Corrector method to use.')

    # Simple on/off switches, registered in the order they appear in --help:
    switches = (
        ('-d', '--debug', 'Print debug messages.'),
        ('-q', '--quiet', 'Only report warnings and errors.'),
        ('-p', '--plot', 'Save plots when running.'),
        ('-r', '--random', 'Run on random target from TODO-list.'),
        ('-t', '--test',
         'Use test data and ignore TESSCORR_INPUT environment variable.'),
        ('-a', '--all', 'Run correction on all targets.'),
        ('-o', '--overwrite', 'Overwrite previous runs and start over.'),
    )
    for short_flag, long_flag, description in switches:
        cli.add_argument(short_flag, long_flag,
                         action='store_true',
                         help=description)

    # Options that narrow down which targets are processed:
    filters = cli.add_argument_group('Filter which targets to process')
    filters.add_argument('--priority', type=int, default=None,
                         help='Priority of target.')
    filters.add_argument('--starid', type=int, default=None,
                         help='TIC identifier of target.')
    filters.add_argument('--sector', type=int, default=None,
                         help='TESS Sector.')
    filters.add_argument('--cadence', type=CadenceType, default=None,
                         choices=('ffi', 1800, 600, 120, 20),
                         help='Cadence. Default is to run all.')
    filters.add_argument('--camera', type=int, default=None,
                         choices=(1, 2, 3, 4),
                         help='TESS Camera. Default is to run all cameras.')
    filters.add_argument('--ccd', type=int, default=None,
                         choices=(1, 2, 3, 4),
                         help='TESS CCD. Default is to run all CCDs.')

    # Optional positional directories:
    cli.add_argument(
        'input_folder',
        nargs='?',
        type=str,
        default=None,
        help='Input directory. This directory should contain a TODO-file and corresponding lightcurves.')
    cli.add_argument('output_folder',
                     nargs='?',
                     type=str,
                     default=None,
                     help='Directory to save output in.')
    args = cli.parse_args()

    # At least one of --all, --random, --starid or --priority must be given:
    target_selected = (args.all
                       or args.starid is not None
                       or args.priority is not None
                       or args.random)
    if not target_selected:
        cli.error(
            "Please select either a specific STARID, PRIORITY or RANDOM.")

    # ---- Logging ----------------------------------------------------------
    # --quiet takes precedence over --debug when both are given.
    if args.quiet:
        logging_level = logging.WARNING
    elif args.debug:
        logging_level = logging.DEBUG
    else:
        logging_level = logging.INFO

    console = logging.StreamHandler()
    console.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

    # Configure this module's logger and the 'corrections' package logger
    # alike; the console handler is only attached where none exists yet.
    logger = logging.getLogger(__name__)
    for log in (logger, logging.getLogger('corrections')):
        if not log.hasHandlers():
            log.addHandler(console)
        log.setLevel(logging_level)

    # ---- Input / output directories ---------------------------------------
    input_folder = args.input_folder
    if input_folder is None:
        # Fall back to the test data shipped next to this file, unless the
        # TESSCORR_INPUT environment variable is set and --test is not used.
        bundled_input = os.path.abspath(
            os.path.join(os.path.dirname(__file__), 'tests', 'input'))
        input_folder = (bundled_input if args.test else os.environ.get(
            'TESSCORR_INPUT', bundled_input))

    output_folder = args.output_folder
    if output_folder is None:
        # Default: a 'lightcurves' directory beside the input directory,
        # overridable through TESSCORR_OUTPUT.
        sibling_output = os.path.join(os.path.dirname(input_folder),
                                      'lightcurves')
        output_folder = os.environ.get('TESSCORR_OUTPUT', sibling_output)

    logger.info("Loading input data from '%s'", input_folder)
    logger.info("Putting output data in '%s'", output_folder)

    # Create the output directory if it is not already there:
    os.makedirs(output_folder, exist_ok=True)

    # ---- Run the corrections ----------------------------------------------
    # Target constraints, taken straight from the parsed filter options:
    constraint_keys = ('priority', 'starid', 'sector', 'cadence', 'camera',
                       'ccd')
    constraints = {key: getattr(args, key) for key in constraint_keys}

    # Look up the corrector class for the requested method:
    CorrClass = corrections.corrclass(args.method)

    with corrections.TaskManager(input_folder,
                                 overwrite=args.overwrite,
                                 cleanup_constraints=constraints) as tm:
        with CorrClass(input_folder, plot=args.plot) as corr:
            while True:
                # Pick the next task, either at random or by constraints:
                task = (tm.get_random_task()
                        if args.random else tm.get_task(**constraints))
                if task is None:
                    # Nothing left to process.
                    break

                # Correct the target and hand the outcome to the TaskManager:
                tm.save_results(
                    corr.correct(task, output_folder=output_folder))

                # Without --all only a single target is processed:
                if not args.all:
                    break