Example No. 1
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Read and process the input features.
  features = _process_tce(config.inputs.features)

  # Create an input function.
  def input_fn():
    return tf.data.Dataset.from_tensors({"time_series_features": features})

  # Generate the predictions.
  for predictions in estimator.predict(input_fn):
    assert len(predictions) == 1
    print("Prediction:", predictions[0])
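The script above reads four absl flags that are defined elsewhere in the file. A minimal sketch of how they might be declared is shown below; the flag names come from the code above, while the help strings, defaults, and the app.run entry point are assumptions.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("model", None, "Name of the model class.")
flags.DEFINE_string("config_name", None, "Name of a built-in model configuration.")
flags.DEFINE_string("config_json", None,
                    "JSON string or JSON file containing the model configuration.")
flags.DEFINE_string("model_dir", None, "Directory containing model checkpoints.")

if __name__ == "__main__":
  app.run(main)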
Example No. 2
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a summary
  # file in the model_dir.
  eval_steps = None  # Evaluate over all examples in the file.
  eval_args = {FLAGS.eval_name: (input_fn, eval_steps)}
  estimator_runner.evaluate(estimator, eval_args)
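estimator_runner.evaluate is a project-specific wrapper not shown in this excerpt. Assuming eval_args maps an evaluation name to an (input_fn, eval_steps) pair, as constructed above, the same evaluation expressed with the stock tf.estimator API would look roughly like the sketch below.

# Sketch only, not the project's implementation: run each named evaluation with
# the standard Estimator API. steps=None evaluates until the input is exhausted,
# i.e. over all examples in the file.
for name, (input_fn, steps) in eval_args.items():
  metrics = estimator.evaluate(input_fn, steps=steps, name=name)
  print(name, metrics)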
Example No. 3
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

  config = configdict.ConfigDict(config)
  config_util.log_and_save_config(config, FLAGS.model_dir)

  # Create the estimator.
  run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
  estimator = estimator_util.create_estimator(model_class, config.hparams,
                                              run_config, FLAGS.model_dir)

  # Create an input function that reads the training dataset. We iterate through
  # the dataset one epoch at a time if we are alternating with evaluation;
  # otherwise we iterate indefinitely.
  train_input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.train_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.TRAIN,
      shuffle_values_buffer=FLAGS.shuffle_buffer_size,
      repeat=1 if FLAGS.eval_files else None)

  if not FLAGS.eval_files:
    estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
  else:
    eval_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.eval_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.EVAL)
    eval_args = {
        "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
    }

    for _ in estimator_runner.continuous_train_and_eval(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_train_and_eval() yields evaluation metrics after each
      # training epoch. We don't do anything here.
      pass
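estimator_runner.continuous_train_and_eval is likewise a project wrapper. The sketch below shows the kind of alternation it presumably performs, written with stock Estimator calls; the exact stopping and yielding behavior is an assumption.

# Sketch of a possible train/eval alternation. Because the training input_fn
# uses repeat=1, each train() call consumes one epoch, and max_steps caps the
# total number of training steps across all calls.
while True:
  estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
  for eval_name, (input_fn, eval_steps) in eval_args.items():
    metrics = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name)
  if metrics["global_step"] >= FLAGS.train_steps:
    break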
Example No. 4
    def __init__(self,
                 model_dir,
                 checkpoint_filename=None,
                 apply_relu_to_embeddings=False,
                 align_to_predictions=False,
                 interpolate_missing_time=False):
        """Initializes the DoFn.

        Args:
          model_dir: Directory containing AstroWaveNet checkpoints.
          checkpoint_filename: Optional name of the AstroWaveNet checkpoint file to
            use. If not specified, the most recent checkpoint is used.
          apply_relu_to_embeddings: Whether to pass the embeddings through a ReLU
            function.
          align_to_predictions: Whether to align embeddings with the time value that
            the embedding vector was used to predict (as opposed to the most recent
            time value included in the receptive field).
          interpolate_missing_time: Whether to interpolate missing time values and
            return their embeddings. Otherwise, missing time values are removed.
        """
        config = config_util.parse_json(os.path.join(model_dir, "config.json"))
        config = configdict.ConfigDict(config)

        if checkpoint_filename:
            checkpoint_file = os.path.join(model_dir, checkpoint_filename)
        else:
            checkpoint_file = tf.train.latest_checkpoint(model_dir)
            if not checkpoint_file:
                raise ValueError(
                    "No checkpoint file found in: {}".format(model_dir))

        self.config = config
        self.checkpoint_file = checkpoint_file
        self.apply_relu_to_embeddings = apply_relu_to_embeddings
        self.align_to_predictions = align_to_predictions
        self.interpolate_missing_time = interpolate_missing_time
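Only __init__ is shown here; the enclosing class name and its process() method are not part of the excerpt. Assuming the class is a Beam DoFn that extracts AstroWaveNet embeddings, usage might look like the following sketch, in which the class name, the flag, and the step name are all hypothetical.

# Hypothetical usage; ExtractEmbeddingsDoFn, the flag, and the step name are
# illustrative only.
extract_embeddings = ExtractEmbeddingsDoFn(
    model_dir=FLAGS.astrowavenet_model_dir,
    apply_relu_to_embeddings=True,
    interpolate_missing_time=True)

embeddings = (
    light_curves
    | "extract_embeddings" >> beam.ParDo(extract_embeddings))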
Example No. 5
  def pipeline(root):
    """Beam pipeline for running transit searches with Box Least Squares."""
    # Parse config.
    config = configdict.ConfigDict(config_util.parse_json(FLAGS.config_json))

    # Choose periods.
    period_min = config.period_min
    period_max = config.period_max
    period_sampling_args = config.period_sampling_args or {}
    if config.period_sampling_method == "andrew":
      choose_periods = _choose_periods_andrew
    elif config.period_sampling_method == "uniform_frequency":
      choose_periods = _choose_periods_uniform_freq
    elif config.period_sampling_method == "logarithmic":
      choose_periods = np.geomspace
    elif config.period_sampling_method == "uniform_period":
      choose_periods = np.linspace
    else:
      raise ValueError("Unrecognized period_sampling_method: {}".format(
          config.period_sampling_method))

    all_periods = choose_periods(period_min, period_max, **period_sampling_args)
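    # Note: np.geomspace and np.linspace share the signature (start, stop, num),
    # so for the "logarithmic" and "uniform_period" methods period_sampling_args
    # is expected to supply at least {"num": ...}; the _choose_periods_* helpers
    # presumably accept the same leading (period_min, period_max) arguments.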

    # Choose nbins.
    nbins_args = config.nbins_args or {}
    all_nbins = []
    for period in all_periods:
      if config.nbins_method == "andrew":
        all_nbins.append(_choose_nbins_andrew(period, **nbins_args))
      elif config.nbins_method == "constant":
        all_nbins.append(nbins_args["num"])
      else:
        raise ValueError("Unrecognized nbins_method: {}".format(
            config.nbins_method))

    # Write the config.
    config_json = config.to_json(indent=2)
    root | beam.Create([config_json]) | "write_config" >> beam.io.WriteToText(
        os.path.join(FLAGS.output_dir, "config.json"),
        num_shards=1,
        shard_name_template="")

    # Initialize DoFns.
    # TODO(shallue): I think I can pass these as kwargs into ParDo.
    read_light_curve = light_curve_fns.ReadLightCurveDoFn(
        FLAGS.kepler_data_dir,
        injected_group=config.injected_group,
        scramble_type=config.scramble_type,
        invert_light_curves=config.invert_light_curves)

    # process_light_curve_for_astronet = light_curve_fns.ProcessLightCurveDoFn(
    #     gap_width=config.predictions.gap_width,
    #     normalize_method=config.predictions.normalize_method,
    #     normalize_args=config.predictions.normalize_args,
    #     upward_outlier_sigma_cut=config.predictions.upward_outlier_sigma_cut,
    #     output_name="light_curve_for_predictions")

    generate_periodogram = bls_fns.GeneratePeriodogramDoFn(
        all_periods, all_nbins, config.weight_min_factor,
        config.duration_density_min, config.duration_min_days,
        config.duration_density_max, config.duration_min_fraction)

    compute_top_results = bls_fns.TopResultsDoFn(config.score_methods,
                                                 config.ignore_negative_depth)

    get_top_result = bls_fns.GetTopResultDoFn(config.top_detection_score_method)

    fit_transit_params = transit_fns.FitTransitParametersDoFn()

    count_transits = transit_fns.CountTransitsDoFn(
        config.complete_transit_fraction)

    # make_predictions = prediction_fns.MakePredictionsDoFn(
    #     FLAGS.astronet_model, FLAGS.astronet_config_name,
    #     FLAGS.astronet_config_json, FLAGS.astronet_model_dir)

    postprocess_for_next_detection = bls_fns.PostProcessForNextDetectionDoFn(
        score_threshold=config.top_detection_score_threshold)

    # Read Kepler IDs.
    # Output: PCollection({"kepler_id"})
    kep_ids = (
        root
        | "read_kep_ids" >> beam.io.textio.ReadFromText(
            FLAGS.input_path, coder=kepler_id.KeplerIdCoder())
        | "create_input_dicts" >>
        beam.Map(lambda kep_id: {"kepler_id": kep_id.value}))

    # Read light curves.
    # Input: PCollection({"kepler_id"})
    # Output: PCollection({"kepler_id", "raw_light_curve"})
    raw_light_curves = (
        kep_ids
        | "read_light_curves" >> beam.ParDo(read_light_curve))
    # | "process_light_curve_for_astronet" >>
    # beam.ParDo(process_light_curve_for_astronet))

    if FLAGS.save_intermediate_output:
      _write_output(
          raw_light_curves,
          output_name="raw-light-curves",
          value_name="raw_light_curve",
          value_coder=beam.coders.ProtoCoder(light_curve_pb2.RawLightCurve))

    # csv_lines = []
    for planet_num in range(config.max_detections):
      if planet_num > config.clip_downward_outliers_after_planet_num:
        downward_outlier_sigma_cut = config.downward_outlier_sigma_cut
      else:
        downward_outlier_sigma_cut = None

      process_light_curve = light_curve_fns.ProcessLightCurveDoFn(
          gap_width=config.gap_width,
          normalize_method=config.normalize_method,
          normalize_args=config.normalize_args,
          upward_outlier_sigma_cut=config.upward_outlier_sigma_cut,
          downward_outlier_sigma_cut=downward_outlier_sigma_cut,
          remove_events_width_factor=config.remove_events_width_factor)

      # Process light curves.
      # Input: PCollection({
      #   "kepler_id",
      #   "raw_light_curve",
      #   "events_to_remove",  (optional)
      #  })
      # Output: PCollection({
      #   "kepler_id",
      #   "raw_light_curve",
      #   "light_curve",
      # })
      light_curves = (
          raw_light_curves | "process_light_curves-%d" % planet_num >>
          beam.ParDo(process_light_curve))

      # Generate periodograms.
      # Input: PCollection({
      #   "kepler_id",
      #   "raw_light_curve",
      #   "light_curve",
      #  })
      # Output: PCollection({
      #   "kepler_id",
      #   "raw_light_curve",
      #   "light_curve",
      #   "periodogram",
      # })
      periodograms = (
          light_curves | "generate_periodogram-%d" % planet_num >>
          beam.ParDo(generate_periodogram))

      # Compute top results.
      # Input: PCollection({
      #   "kepler_id",
      #   "raw_light_curve",
      #   "light_curve",
      #   "periodogram",
      # })
      # Output: PCollection({
      #   "kepler_id",
      #   "raw_light_curve",
      #   "light_curve",
      #   "periodogram",
      #   "top_results",
      #   "top_result",
      # })
      top_results = (
          periodograms
          | "compute_top_results-%d" % planet_num >>
          beam.ParDo(compute_top_results)
          | "get_top_result-%d" % planet_num >> beam.ParDo(get_top_result)
          | "count_transits-%d" % planet_num >> beam.ParDo(count_transits)
          | "fit_transit_params-%d" % planet_num >>
          beam.ParDo(fit_transit_params))
      # | "make_predictions-%d" % planet_num >> beam.ParDo(make_predictions))

      # csv_lines.append(top_results
      #                 | "extract_csv_%d" % planet_num >> beam.ParDo(
      #                     prediction_fns.ToCsvDoFn(planet_num=planet_num)))

      # Write the outputs.
      _write_output(
          top_results,
          output_name="top-results-%d" % planet_num,
          value_name="top_results",
          value_coder=beam.coders.ProtoCoder(bls_pb2.TopResults))
      # Write the top result, including the fitted transit parameters.
      _write_output(
          top_results,
          output_name="scored-result-with-transit-fit-%d" % planet_num,
          value_name="top_result",
          value_coder=beam.coders.ProtoCoder(bls_pb2.ScoredResult))
      if FLAGS.save_intermediate_output:
        _write_output(
            light_curves,
            output_name="light-curves-%d" % planet_num,
            value_name="light_curve",
            value_coder=beam.coders.ProtoCoder(light_curve_pb2.LightCurve))
        _write_output(
            periodograms,
            output_name="periodograms-%d" % planet_num,
            value_name="periodogram",
            value_coder=beam.coders.ProtoCoder(bls_pb2.Periodogram))

      # Process light curves for the next round.
      if planet_num < config.max_detections - 1:
        # Extract detected events.
        # Input: PCollection({
        #   "kepler_id",
        #   "raw_light_curve",
        #   "light_curve",
        #   "periodogram",
        #   "top_results",
        # })
        # Output: PCollection({
        #   "kepler_id",
        #   "raw_light_curve",
        #   "events_to_remove",
        # })
        raw_light_curves = (
            top_results
            | "postprocess-%d" % planet_num >>
            beam.ParDo(postprocess_for_next_detection))
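The pipeline writes its outputs through a _write_output helper that is not included in this excerpt. Judging from the call sites above (a PCollection of dicts, an output name, a value key, and a proto coder), a plausible sketch follows; the project's actual helper may differ.

# Plausible sketch of _write_output, inferred from its call sites and not taken
# from the project: pull one value out of each element dict and write it as a
# TFRecord file under FLAGS.output_dir.
def _write_output(pcollection, output_name, value_name, value_coder):
  _ = (
      pcollection
      | "extract_%s" % output_name >> beam.Map(lambda inputs: inputs[value_name])
      | "write_%s" % output_name >> beam.io.tfrecordio.WriteToTFRecord(
          os.path.join(FLAGS.output_dir, output_name),
          coder=value_coder))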