def process(self, inputs):
  """Generates model predictions for every example under one checkpoint.

  Args:
    inputs: Tuple of (checkpoint_path, input_file_pattern).

  Yields:
    A dict of NumPy prediction values for a single input example, with padding
    stripped from sequence outputs and a "global_step" entry added.
  """
  checkpoint_path, input_file_pattern = inputs
  global_step = _get_step_from_checkpoint_path(checkpoint_path)

  # Create the input_fn.
  dataset_builder = kepler_light_curves.KeplerLightCurves(
      input_file_pattern,
      mode=tf.estimator.ModeKeys.PREDICT,
      config_overrides=self.dataset_overrides)
  tf.logging.info("Dataset config: %s",
                  config_util.to_json(dataset_builder.config))
  input_fn = estimator_util.create_input_fn(dataset_builder)

  # Create the estimator.
  estimator = estimator_util.create_estimator(astrowavenet_model.AstroWaveNet,
                                              self.hparams)

  # Generate predictions.
  for predictions in estimator.predict(
      input_fn, checkpoint_path=checkpoint_path):
    # Add global_step so downstream consumers know which checkpoint produced
    # this prediction.
    predictions["global_step"] = global_step

    # Squeeze and un-pad the sequences. Padded timesteps have weight 0, so the
    # real sequence length is the index of the last nonzero weight.
    weights = np.squeeze(predictions["seq_weights"])
    real_length = len(weights)
    while real_length > 0 and weights[real_length - 1] == 0:
      real_length -= 1
    # Rebinding values for existing keys during .items() iteration is safe;
    # no keys are added or removed inside the loop.
    for name, value in predictions.items():
      # Fix: use the value already yielded by .items() instead of the
      # redundant second lookup np.squeeze(predictions[name]).
      value = np.squeeze(value)
      if value.shape:
        # Only sliceable (non-scalar) values are un-padded; scalars such as
        # "global_step" pass through unchanged.
        value = value[0:real_length]
      predictions[name] = value

    yield predictions
def main(argv):
  """Trains and/or evaluates the AstroWaveNet model per FLAGS.schedule."""
  del argv  # Unused.

  config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
  config_overrides = json.loads(FLAGS.config_overrides)
  for key in config_overrides:
    if key not in ["dataset", "hparams"]:
      raise ValueError("Unrecognized config override: {}".format(key))
  # Only "hparams" overrides are merged into config here; "dataset" overrides
  # are forwarded to the dataset builder via _create_input_fn below.
  config.hparams.update(config_overrides.get("hparams", {}))

  # Log configs.
  configs_json = [
      ("config_overrides", config_util.to_json(config_overrides)),
      ("config", config_util.to_json(config)),
  ]
  for config_name, config_json in configs_json:
    tf.logging.info("%s: %s", config_name, config_json)

  # Create the estimator.
  run_config = _create_run_config()
  estimator = estimator_util.create_estimator(astrowavenet_model.AstroWaveNet,
                                              config.hparams, run_config,
                                              FLAGS.model_dir,
                                              FLAGS.eval_batch_size)

  if FLAGS.schedule in ["train", "train_and_eval"]:
    # Save configs alongside the model so runs are reproducible.
    tf.gfile.MakeDirs(FLAGS.model_dir)
    for config_name, config_json in configs_json:
      filename = os.path.join(FLAGS.model_dir, "{}.json".format(config_name))
      with tf.gfile.Open(filename, "w") as f:
        f.write(config_json)

    train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
                                      config_overrides.get("dataset"))
    train_hooks = []
    if FLAGS.schedule == "train":
      estimator.train(
          train_input_fn, hooks=train_hooks, max_steps=FLAGS.train_steps)
    else:
      assert FLAGS.schedule == "train_and_eval"
      eval_args = _create_eval_args(config_overrides.get("dataset"))
      for _ in estimator_runner.continuous_train_and_eval(
          estimator=estimator,
          train_input_fn=train_input_fn,
          eval_args=eval_args,
          local_eval_frequency=FLAGS.local_eval_frequency,
          train_hooks=train_hooks,
          train_steps=FLAGS.train_steps):
        # continuous_train_and_eval() yields evaluation metrics after each
        # FLAGS.local_eval_frequency. It also saves and logs them, so we don't
        # do anything here.
        pass
  else:
    assert FLAGS.schedule == "continuous_eval"
    eval_args = _create_eval_args(config_overrides.get("dataset"))
    for _ in estimator_runner.continuous_eval(
        estimator=estimator,
        eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # Fix: this comment previously referred to continuous_train_and_eval().
      # continuous_eval() yields evaluation metrics after each checkpoint. It
      # also saves and logs them, so we don't do anything here.
      pass
def _create_input_fn(mode, config_overrides=None):
  """Creates an Estimator input_fn.

  Args:
    mode: A tf.estimator.ModeKeys value (e.g. TRAIN).
    config_overrides: Optional overrides applied to the dataset config.

  Returns:
    An input_fn suitable for passing to a TensorFlow Estimator.
  """
  dataset_builder = _create_dataset_builder(mode, config_overrides)
  builder_config_json = config_util.to_json(dataset_builder.config)
  tf.logging.info("Dataset config for mode '%s': %s", mode,
                  builder_config_json)
  return estimator_util.create_input_fn(dataset_builder)