def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Read and process the input features.
  features = _process_tce(config.inputs.features)

  # Create an input function.
  def input_fn():
    return tf.data.Dataset.from_tensors({"time_series_features": features})

  # Generate the predictions.
  for predictions in estimator.predict(input_fn):
    assert len(predictions) == 1
    print("Prediction:", predictions[0])
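# The flag definitions and entry point for this script are not shown above. A
# minimal sketch, assuming absl is used; only the flags referenced directly in
# main() are defined, and the help strings and defaults are assumptions.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("model", None, "Name of the model class.")
flags.DEFINE_string("config_name", None, "Name of a predefined model config.")
flags.DEFINE_string("config_json", None,
                    "JSON string or JSON file containing the model config.")
flags.DEFINE_string("model_dir", None,
                    "Directory containing model checkpoints.")

if __name__ == "__main__":
  app.run(main)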
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a
  # summary file in the model_dir.
  eval_steps = None  # Evaluate over all examples in the file.
  eval_args = {FLAGS.eval_name: (input_fn, eval_steps)}
  estimator_runner.evaluate(estimator, eval_args)
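# The summary file written under model_dir can be inspected with TensorBoard,
# e.g. (the path is a placeholder, not from the code above):
#
#   tensorboard --logdir="${MODEL_DIR}"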
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)
  config_util.log_and_save_config(config, FLAGS.model_dir)

  # Create the estimator.
  run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
  estimator = estimator_util.create_estimator(model_class, config.hparams,
                                              run_config, FLAGS.model_dir)

  # Create an input function that reads the training dataset. We iterate
  # through the dataset one epoch at a time if we are alternating with
  # evaluation; otherwise we iterate indefinitely.
  train_input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.train_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.TRAIN,
      shuffle_values_buffer=FLAGS.shuffle_buffer_size,
      repeat=1 if FLAGS.eval_files else None)

  if not FLAGS.eval_files:
    estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
  else:
    eval_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.eval_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.EVAL)
    eval_args = {
        "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
    }

    for _ in estimator_runner.continuous_train_and_eval(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_train_and_eval() yields evaluation metrics after each
      # training epoch. We don't do anything with them here.
      pass
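# Example invocation (a sketch only; the script name, model name, config name,
# and paths are assumptions, not taken from the code above):
#
#   python train.py \
#     --model=AstroCNNModel \
#     --config_name=local_global \
#     --train_files="${TFRECORD_DIR}/train-*" \
#     --eval_files="${TFRECORD_DIR}/val-*" \
#     --model_dir="${MODEL_DIR}" \
#     --train_steps=10000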
def __init__(self,
             model_dir,
             checkpoint_filename=None,
             apply_relu_to_embeddings=False,
             align_to_predictions=False,
             interpolate_missing_time=False):
  """Initializes the DoFn.

  Args:
    model_dir: Directory containing AstroWaveNet checkpoints.
    checkpoint_filename: Optional filename of the AstroWaveNet checkpoint to
      use. If not specified, the most recent checkpoint is used.
    apply_relu_to_embeddings: Whether to pass the embeddings through a ReLU
      function.
    align_to_predictions: Whether to align embeddings with the time value that
      the embedding vector was used to predict (as opposed to the most recent
      time value included in the receptive field).
    interpolate_missing_time: Whether to interpolate missing time values and
      return their embeddings. Otherwise, missing time values are removed.
  """
  config = config_util.parse_json(os.path.join(model_dir, "config.json"))
  config = configdict.ConfigDict(config)

  if checkpoint_filename:
    checkpoint_file = os.path.join(model_dir, checkpoint_filename)
  else:
    checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if not checkpoint_file:
      raise ValueError("No checkpoint file found in: {}".format(model_dir))

  self.config = config
  self.checkpoint_file = checkpoint_file
  self.apply_relu_to_embeddings = apply_relu_to_embeddings
  self.align_to_predictions = align_to_predictions
  self.interpolate_missing_time = interpolate_missing_time
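# Usage sketch (not from the original code): the constructor above belongs to a
# Beam DoFn that computes AstroWaveNet embeddings. The class name
# "ExtractEmbeddingsDoFn", the pipeline shape, and the model path below are all
# hypothetical, for illustration only.
import apache_beam as beam


def build_embedding_stage(light_curves, model_dir):
  # Applies the DoFn above to a PCollection of light-curve dicts.
  return (
      light_curves
      | "extract_embeddings" >> beam.ParDo(
          ExtractEmbeddingsDoFn(  # hypothetical name for the class above
              model_dir=model_dir,
              align_to_predictions=True,
              interpolate_missing_time=True)))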
def pipeline(root):
  """Beam pipeline for running transit searches with Box Least Squares."""
  # Parse config.
  config = configdict.ConfigDict(config_util.parse_json(FLAGS.config_json))

  # Choose periods.
  period_min = config.period_min
  period_max = config.period_max
  period_sampling_args = config.period_sampling_args or {}
  if config.period_sampling_method == "andrew":
    choose_periods = _choose_periods_andrew
  elif config.period_sampling_method == "uniform_frequency":
    choose_periods = _choose_periods_uniform_freq
  elif config.period_sampling_method == "logarithmic":
    choose_periods = np.geomspace
  elif config.period_sampling_method == "uniform_period":
    choose_periods = np.linspace
  else:
    raise ValueError("Unrecognized period_sampling_method: {}".format(
        config.period_sampling_method))
  all_periods = choose_periods(period_min, period_max, **period_sampling_args)

  # Choose nbins.
  nbins_args = config.nbins_args or {}
  all_nbins = []
  for period in all_periods:
    if config.nbins_method == "andrew":
      all_nbins.append(_choose_nbins_andrew(period, **nbins_args))
    elif config.nbins_method == "constant":
      all_nbins.append(nbins_args["num"])
    else:
      raise ValueError("Unrecognized nbins_method: {}".format(
          config.nbins_method))

  # Write the config.
  config_json = config.to_json(indent=2)
  root | beam.Create([config_json]) | "write_config" >> beam.io.WriteToText(
      os.path.join(FLAGS.output_dir, "config.json"),
      num_shards=1,
      shard_name_template="")

  # Initialize DoFns.
  # TODO(shallue): I think I can pass these as kwargs into ParDo.
  read_light_curve = light_curve_fns.ReadLightCurveDoFn(
      FLAGS.kepler_data_dir,
      injected_group=config.injected_group,
      scramble_type=config.scramble_type,
      invert_light_curves=config.invert_light_curves)

  # process_light_curve_for_astronet = light_curve_fns.ProcessLightCurveDoFn(
  #     gap_width=config.predictions.gap_width,
  #     normalize_method=config.predictions.normalize_method,
  #     normalize_args=config.predictions.normalize_args,
  #     upward_outlier_sigma_cut=config.predictions.upward_outlier_sigma_cut,
  #     output_name="light_curve_for_predictions")

  generate_periodogram = bls_fns.GeneratePeriodogramDoFn(
      all_periods, all_nbins, config.weight_min_factor,
      config.duration_density_min, config.duration_min_days,
      config.duration_density_max, config.duration_min_fraction)

  compute_top_results = bls_fns.TopResultsDoFn(config.score_methods,
                                               config.ignore_negative_depth)

  get_top_result = bls_fns.GetTopResultDoFn(config.top_detection_score_method)

  fit_transit_params = transit_fns.FitTransitParametersDoFn()

  count_transits = transit_fns.CountTransitsDoFn(
      config.complete_transit_fraction)

  # make_predictions = prediction_fns.MakePredictionsDoFn(
  #     FLAGS.astronet_model, FLAGS.astronet_config_name,
  #     FLAGS.astronet_config_json, FLAGS.astronet_model_dir)

  postprocess_for_next_detection = bls_fns.PostProcessForNextDetectionDoFn(
      score_threshold=config.top_detection_score_threshold)

  # Read Kepler IDs.
  # Output: PCollection({"kepler_id"})
  kep_ids = (
      root
      | "read_kep_ids" >> beam.io.textio.ReadFromText(
          FLAGS.input_path, coder=kepler_id.KeplerIdCoder())
      | "create_input_dicts" >>
      beam.Map(lambda kep_id: {"kepler_id": kep_id.value}))

  # Read light curves.
# Input: PCollection({"kepler_id"}) # Output: PCollection({"kepler_id", "raw_light_curve"}) raw_light_curves = ( kep_ids | "read_light_curves" >> beam.ParDo(read_light_curve)) # | "process_light_curve_for_astronet" >> # beam.ParDo(process_light_curve_for_astronet)) if FLAGS.save_intermediate_output: _write_output( raw_light_curves, output_name="raw-light-curves", value_name="raw_light_curve", value_coder=beam.coders.ProtoCoder(light_curve_pb2.RawLightCurve)) # csv_lines = [] for planet_num in range(config.max_detections): if planet_num > config.clip_downward_outliers_after_planet_num: downward_outlier_sigma_cut = config.downward_outlier_sigma_cut else: downward_outlier_sigma_cut = None process_light_curve = light_curve_fns.ProcessLightCurveDoFn( gap_width=config.gap_width, normalize_method=config.normalize_method, normalize_args=config.normalize_args, upward_outlier_sigma_cut=config.upward_outlier_sigma_cut, downward_outlier_sigma_cut=downward_outlier_sigma_cut, remove_events_width_factor=config.remove_events_width_factor) # Process light curves. # Input: PCollection({ # "kepler_id", # "raw_light_curve", # "events_to_remove", (optional) # }) # Output: PCollection({ # "kepler_id", # "raw_light_curve", # "light_curve", # }) light_curves = ( raw_light_curves | "process_light_curves-%d" % planet_num >> beam.ParDo(process_light_curve)) # Generate periodograms. # Input: PCollection({ # "kepler_id", # "raw_light_curve", # "light_curve", # }) # Output: PCollection({ # "kepler_id", # "raw_light_curve", # "light_curve", # "periodogram", # }) periodograms = ( light_curves | "generate_periodogram-%d" % planet_num >> beam.ParDo(generate_periodogram)) # Compute top results. # Input: PCollection({ # "kepler_id", # "raw_light_curve", # "light_curve", # "periodogram", # }) # Output: PCollection({ # "kepler_id", # "raw_light_curve", # "light_curve", # "periodogram", # "top_results", # "top_result", # }) top_results = ( periodograms | "compute_top_results-%d" % planet_num >> beam.ParDo(compute_top_results) | "get_top_result-%d" % planet_num >> beam.ParDo(get_top_result) | "count_transits-%d" % planet_num >> beam.ParDo(count_transits) | "fit_transit_params-%d" % planet_num >> beam.ParDo(fit_transit_params)) # | "make_predictions-%d" % planet_num >> beam.ParDo(make_predictions)) # csv_lines.append(top_results # | "extract_csv_%d" % planet_num >> beam.ParDo( # prediction_fns.ToCsvDoFn(planet_num=planet_num))) # Write the outputs. _write_output( top_results, output_name="top-results-%d" % planet_num, value_name="top_results", value_coder=beam.coders.ProtoCoder(bls_pb2.TopResults)) # Write the outputs. _write_output( top_results, output_name="scored-result-with-transit-fit-%d" % planet_num, value_name="top_result", value_coder=beam.coders.ProtoCoder(bls_pb2.ScoredResult)) if FLAGS.save_intermediate_output: _write_output( light_curves, output_name="light-curves-%d" % planet_num, value_name="light_curve", value_coder=beam.coders.ProtoCoder(light_curve_pb2.LightCurve)) _write_output( periodograms, output_name="periodograms-%d" % planet_num, value_name="periodogram", value_coder=beam.coders.ProtoCoder(bls_pb2.Periodogram)) # Process light curves for the next round. if planet_num < config.max_detections - 1: # Extract detected events. 
# Input: PCollection({ # "kepler_id", # "raw_light_curve", # "light_curve", # "periodogram", # "top_results", # }) # Output: PCollection({ # "kepler_id", # "raw_light_curve", # "events_to_remove", # }) raw_light_curves = ( top_results | "postprocess-%d" % planet_num >> beam.ParDo(postprocess_for_next_detection))