# Example 1
    def process(self, inputs):
        """Generates model predictions for one checkpoint / file-pattern pair.

        Args:
          inputs: A (checkpoint_path, input_file_pattern) tuple.

        Yields:
          Dict of NumPy prediction values from the estimator, squeezed and with
          trailing zero-weight padding removed, plus a "global_step" entry
          holding the checkpoint's global step.
        """
        checkpoint_path, input_file_pattern = inputs
        global_step = _get_step_from_checkpoint_path(checkpoint_path)

        # Create the input_fn.
        dataset_builder = kepler_light_curves.KeplerLightCurves(
            input_file_pattern,
            mode=tf.estimator.ModeKeys.PREDICT,
            config_overrides=self.dataset_overrides)
        tf.logging.info("Dataset config: %s",
                        config_util.to_json(dataset_builder.config))
        input_fn = estimator_util.create_input_fn(dataset_builder)

        # Create the estimator.
        estimator = estimator_util.create_estimator(
            astrowavenet_model.AstroWaveNet, self.hparams)

        # Generate predictions.
        for predictions in estimator.predict(input_fn,
                                             checkpoint_path=checkpoint_path):
            # Tag each prediction dict with the checkpoint's global step.
            predictions["global_step"] = global_step

            # Squeeze and un-pad the sequences. real_length is the sequence
            # length with trailing zero weights (padding) stripped.
            weights = np.squeeze(predictions["seq_weights"])
            real_length = len(weights)
            while real_length > 0 and weights[real_length - 1] == 0:
                real_length -= 1
            # Reassigning existing keys while iterating items() is safe: no
            # keys are added or removed inside the loop.
            for name, value in predictions.items():
                # Use the unpacked value directly; the original re-fetched
                # predictions[name] redundantly.
                value = np.squeeze(value)
                # Scalars (e.g. global_step) squeeze to shape () and are
                # left unchanged.
                if value.shape:
                    predictions[name] = value[0:real_length]

            yield predictions
def main(argv):
  """Trains and/or evaluates an AstroWaveNet model, per FLAGS.schedule."""
  del argv  # Unused.

  # Base configuration, with JSON-encoded overrides from the command line.
  config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
  config_overrides = json.loads(FLAGS.config_overrides)
  for key in config_overrides:
    if key not in ["dataset", "hparams"]:
      raise ValueError("Unrecognized config override: {}".format(key))
  # Only "hparams" overrides are merged into config here; "dataset" overrides
  # are passed separately to the input_fn / eval-args constructors below.
  config.hparams.update(config_overrides.get("hparams", {}))

  # Log configs.
  configs_json = [
      ("config_overrides", config_util.to_json(config_overrides)),
      ("config", config_util.to_json(config)),
  ]
  for config_name, config_json in configs_json:
    tf.logging.info("%s: %s", config_name, config_json)

  # Create the estimator.
  run_config = _create_run_config()
  estimator = estimator_util.create_estimator(
      astrowavenet_model.AstroWaveNet, config.hparams, run_config,
      FLAGS.model_dir, FLAGS.eval_batch_size)

  if FLAGS.schedule in ["train", "train_and_eval"]:
    # Save configs as JSON files into the model directory for reproducibility.
    tf.gfile.MakeDirs(FLAGS.model_dir)
    for config_name, config_json in configs_json:
      filename = os.path.join(FLAGS.model_dir, "{}.json".format(config_name))
      with tf.gfile.Open(filename, "w") as f:
        f.write(config_json)

    train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
                                      config_overrides.get("dataset"))

    train_hooks = []
    if FLAGS.schedule == "train":
      estimator.train(
          train_input_fn, hooks=train_hooks, max_steps=FLAGS.train_steps)
    else:
      assert FLAGS.schedule == "train_and_eval"

      eval_args = _create_eval_args(config_overrides.get("dataset"))
      for _ in estimator_runner.continuous_train_and_eval(
          estimator=estimator,
          train_input_fn=train_input_fn,
          eval_args=eval_args,
          local_eval_frequency=FLAGS.local_eval_frequency,
          train_hooks=train_hooks,
          train_steps=FLAGS.train_steps):
        # continuous_train_and_eval() yields evaluation metrics after each
        # FLAGS.local_eval_frequency. It also saves and logs them, so we don't
        # do anything here.
        pass

  else:
    assert FLAGS.schedule == "continuous_eval"

    eval_args = _create_eval_args(config_overrides.get("dataset"))
    for _ in estimator_runner.continuous_eval(
        estimator=estimator, eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_eval() yields evaluation metrics after each new checkpoint.
      # It also saves and logs them, so we don't do anything here.
      pass
def _create_input_fn(mode, config_overrides=None):
  """Builds an Estimator input_fn for the given mode.

  Args:
    mode: A tf.estimator.ModeKeys value.
    config_overrides: Optional dataset config overrides.

  Returns:
    An input_fn suitable for passing to an Estimator.
  """
  dataset_builder = _create_dataset_builder(mode, config_overrides)
  tf.logging.info("Dataset config for mode '%s': %s", mode,
                  config_util.to_json(dataset_builder.config))
  input_fn = estimator_util.create_input_fn(dataset_builder)
  return input_fn