Beispiel #1
0
def get_model_config(model_name, config_name):
  """Looks up a model configuration by name.

  Args:
    model_name: Name of the model class; must be a key of _MODELS.
    config_name: Name of a configuration-builder function from the model's
        configurations module.

  Returns:
    config: The requested configuration, wrapped in a configdict.ConfigDict.

  Raises:
    ValueError: If model_name or config_name is unrecognized.
  """
  if model_name not in _MODELS:
    raise ValueError("Unrecognized model name: %s" % model_name)

  # Element [1] of each _MODELS entry is the model's configurations module.
  config_module = _MODELS[model_name][1]

  # Guard only the attribute lookup: an AttributeError raised *inside* the
  # config builder is a genuine bug and must not be reported as an unknown
  # config name.
  try:
    config_builder = getattr(config_module, config_name)
  except AttributeError:
    raise ValueError("Config name '%s' not found in configuration module: %s" %
                     (config_name, config_module.__name__))

  return configdict.ConfigDict(config_builder())
def main(_):
    """Predicts with one or several trained models and saves averaged scores.

    When --average is set, every entry directly under --model_dir is treated
    as a separate trained model and the per-TCE predictions are averaged;
    otherwise --model_dir itself is the only model. Writes one row per TCE
    (TCE id, mean prediction) to prediction_<suffix>.txt.

    Raises:
      ValueError: If not exactly one of --config_name / --config_json is set,
        or if no model directories are found.
    """
    import os  # Local import: os is not guaranteed at file level.

    # Look up the model class.
    model_class = models.get_model_class(FLAGS.model)

    # Look up the model configuration.
    if (FLAGS.config_name is None) == (FLAGS.config_json is None):
        raise ValueError("Exactly one of config_name or config_json is required.")
    config = (
        models.get_model_config(FLAGS.model, FLAGS.config_name)
        if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
    config = configdict.ConfigDict(config)

    if FLAGS.average:
        # glob order is filesystem-dependent; sort for reproducible runs.
        model_dirs = sorted(glob.glob(os.path.join(FLAGS.model_dir, '*')))
    else:
        model_dirs = [FLAGS.model_dir]
    if not model_dirs:
        # Without this, np.savetxt below fails with a confusing fmt error.
        raise ValueError("No model directories found in: %s" % FLAGS.model_dir)

    y_pred = defaultdict(list)
    for model_dir in model_dirs:
        # predict() appends this model's prediction for each TCE to y_pred.
        y_pred = predict(model_dir, config, model_class, y_pred)

    y_pred_average = [[tce, np.mean(preds)] for tce, preds in y_pred.items()]

    np.savetxt('prediction_' + FLAGS.suffix + '.txt', np.array(y_pred_average),
               fmt=['%d', '%4.3f'])
Beispiel #3
0
    def setUp(self):
        """Points the tests at the TFRecord fixture and its input config.

        The fixture holds 10 tensorflow.Example protos; Example i contains:
          global_view = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
          local_view = [0.0, 1.0, 2.0, 3.0]
          aux_feature = 100 + i
          label_str = "PC" if i % 3 == 0 else "AFP" if i % 3 == 1 else "NTP"
        """
        super(BuildDatasetTest, self).setUp()
        self._file_pattern = os.path.join(FLAGS.test_srcdir,
                                          _TEST_TFRECORD_FILE)

        def _feature(is_time_series, length):
            return {"is_time_series": is_time_series, "length": length}

        feature_spec = {
            "global_view": _feature(True, 8),
            "local_view": _feature(True, 4),
            "aux_feature": _feature(False, 1),
        }
        self._input_config = configdict.ConfigDict({"features": feature_spec})
Beispiel #4
0
def main(_):
  """Evaluates a trained model on --eval_files.

  Raises:
    ValueError: If not exactly one of --config_name / --config_json is set.
  """
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration. Validate with an explicit exception
  # rather than `assert`, which is stripped when Python runs with -O.
  if (FLAGS.config_name is None) == (FLAGS.config_json is None):
    raise ValueError(
        "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a summary
  # file in the model_dir.
  estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
  def testZeroHiddenLayers(self):
    # Two length-10 time series plus one length-1 aux feature, with no
    # pre-logits hidden layers.
    def _spec(length, is_time_series):
      return {"length": length, "is_time_series": is_time_series}

    config = configurations.base()
    config["inputs"]["features"] = {
        "time_feature_1": _spec(10, True),
        "time_feature_2": _spec(10, True),
        "aux_feature_1": _spec(1, False),
    }
    config = configdict.ConfigDict(config)
    hparams = config.hparams
    hparams.output_dim = 1
    hparams.num_pre_logits_hidden_layers = 0

    # Build the model in training mode.
    model = astro_model.AstroModel(
        input_ops.build_feature_placeholders(config.inputs.features),
        input_ops.build_labels_placeholder(), hparams,
        tf.estimator.ModeKeys.TRAIN)
    model.build()

    # The 21-wide concat feeds a 21x1 logits kernel directly.
    self.assertShapeEquals((None, 21), model.pre_logits_concat)
    self.assertShapeEquals((21, 1),
                           testing.get_variable_by_name("logits/kernel"))
  def testOneTimeSeriesFeature(self):
    # A single time series feature and no aux features.
    config = configurations.base()
    config["inputs"]["features"] = {
        "time_feature_1": {"length": 10, "is_time_series": True},
    }
    config = configdict.ConfigDict(config)

    # Build model.
    placeholders = input_ops.build_feature_placeholders(config.inputs.features)
    label_placeholder = input_ops.build_labels_placeholder()
    model = astro_model.AstroModel(placeholders, label_placeholder,
                                   config.hparams, tf.estimator.ModeKeys.TRAIN)
    model.build()

    # The input tensor should be reused unchanged as both the hidden layer
    # and the pre-logits concat.
    self.assertItemsEqual(["time_feature_1"],
                          model.time_series_hidden_layers.keys())
    raw = model.time_series_features["time_feature_1"]
    self.assertIs(raw, model.time_series_hidden_layers["time_feature_1"])
    self.assertEqual(len(model.aux_hidden_layers), 0)
    self.assertIs(raw, model.pre_logits_concat)
Beispiel #7
0
def main(_):
  """Reports the model's parameter count, then evaluates it on --eval_files.

  Raises:
    ValueError: If not exactly one of --config_name / --config_json is set.
  """
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration. Validate with an explicit exception
  # rather than `assert`, which is stripped when Python runs with -O.
  if (FLAGS.config_name is None) == (FLAGS.config_json is None):
    raise ValueError(
        "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Print the number of parameters to the console. np.size counts elements
  # without the copy that .flatten() makes.
  # NOTE(review): get_variable_names() lists every variable in the latest
  # checkpoint, which presumably includes global_step and optimizer slot
  # variables, so this may overcount trainable parameters. Confirm before
  # relying on the number.
  n_params = sum(
      np.size(estimator.get_variable_value(name))
      for name in estimator.get_variable_names())
  print("Trainable parameters in model:", int(n_params))

  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a summary
  # file in the model_dir.
  estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
Beispiel #8
0
  def testTwoTimeSeriesFeatures(self):
    # Two time series features (lengths 20 and 5) plus one aux feature.
    feature_spec = {
        "time_feature_1": {"length": 20, "is_time_series": True},
        "time_feature_2": {"length": 5, "is_time_series": True},
        "aux_feature_1": {"length": 1, "is_time_series": False},
    }

    def _rnn_spec(num_layers, num_units):
      # Bidirectional LSTM stack with tanh activation and no dropout.
      return {
          "rnn_num_layers": num_layers,
          "rnn_num_units": num_units,
          "rnn_memory_cells": "lstm",
          "rnn_activation": "tanh",
          "rnn_dropout": 0.0,
          "rnn_direction": "bi"
      }

    config = configurations.base()
    config["inputs"]["features"] = feature_spec
    config["hparams"]["time_series_hidden"] = {
        "time_feature_1": _rnn_spec(2, 16),
        "time_feature_2": _rnn_spec(1, 4),
    }
    config = configdict.ConfigDict(config)

    # Build the RNN model in training mode.
    model = astro_rnn_model.AstroRNNModel(
        input_ops.build_feature_placeholders(config.inputs.features),
        input_ops.build_labels_placeholder(), config.hparams,
        tf.estimator.ModeKeys.TRAIN)
    model.build()

    # Initialize variables and run a forward pass on a fake batch.
    batch_size = 16
    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.test_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      self.assertEqual(0, sess.run(model.global_step))

      feed_dict = input_ops.prepare_feed_dict(
          model,
          testing.fake_features(feature_spec, batch_size=batch_size),
          testing.fake_labels(config.hparams.output_dim,
                              batch_size=batch_size))
      predictions = sess.run(model.predictions, feed_dict=feed_dict)
      self.assertShapeEquals((batch_size, 1), predictions)
Beispiel #9
0
def main(_):
    """Runs a trained model on the features from _process_tce and prints them.

    Raises:
      ValueError: If not exactly one of --config_name / --config_json is set.
    """
    model_class = models.get_model_class(FLAGS.model)

    # Look up the model configuration. Raise explicitly: an `assert` would be
    # stripped under `python -O`, silently skipping the check.
    if (FLAGS.config_name is None) == (FLAGS.config_json is None):
        raise ValueError(
            "Exactly one of --config_name or --config_json is required.")
    config = (models.get_model_config(FLAGS.model, FLAGS.config_name) if
              FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
    config = configdict.ConfigDict(config)

    # Create the estimator.
    estimator = estimator_util.create_estimator(model_class,
                                                config.hparams,
                                                model_dir=FLAGS.model_dir)

    # Read and process the input features.
    features = _process_tce(config.inputs.features)
    # NOTE(review): these two prints look like leftover debugging output.
    print(type(features))
    print(features)

    # Create an input function that yields the features one example at a
    # time, unshuffled.
    def input_fn():
        return {
            "time_series_features":
            tf.estimator.inputs.numpy_input_fn(features,
                                               batch_size=1,
                                               shuffle=False,
                                               queue_capacity=1)()
        }

    # Generate the predictions.
    for predictions in estimator.predict(input_fn):
        assert len(predictions) == 1
        print("Prediction:", predictions[0])
  def testInvalidModeRaisesError(self):
    # "training" is not a tf.estimator.ModeKeys value, so the constructor
    # must reject it.
    base_config = configdict.ConfigDict(configurations.base())
    feature_placeholders = input_ops.build_feature_placeholders(
        base_config.inputs.features)
    label_placeholder = input_ops.build_labels_placeholder()
    with self.assertRaises(ValueError):
      astro_model.AstroModel(feature_placeholders, label_placeholder,
                             base_config.hparams, "training")
Beispiel #11
0
def load_config(output_dir):
    """Loads the JSON configuration stored for an output directory.

    Args:
      output_dir: Directory whose config file path is obtained via
        config_file(output_dir).

    Returns:
      A configdict.ConfigDict holding the parsed JSON contents.
    """
    with tf.io.gfile.GFile(config_file(output_dir), 'r') as f:
        return configdict.ConfigDict(json.load(f))
Beispiel #12
0
  def __init__(self, config_overrides=None):
    """Creates the builder's config: the defaults plus any overrides.

    Args:
      config_overrides: Optional dict or ConfigDict whose entries are applied
        on top of default_config() via ConfigDict.update().
    """
    defaults = self.default_config()
    self.config = configdict.ConfigDict(defaults)
    if config_overrides is None:
      return
    self.config.update(config_overrides)
Beispiel #13
0
 def default_config():
   """Returns the default configuration as a ConfigDict."""
   # Each *_range value is a 2-tuple, presumably (min, max) bounds.
   defaults = {
       "period_range": (0.5, 4),
       "amplitude_range": (1, 1),
       "threshold_ratio_range": (0, 0.99),
       "phase_range": (0, 1),
       "noise_sd_range": (0.1, 0.1),
       "mask_probability": 0.1,
       "light_curve_time_range": (0, 100),
       "light_curve_num_points": 1000,
   }
   return configdict.ConfigDict(defaults)
Beispiel #14
0
    def testEvalMode(self):
        """Builds AstroModel in EVAL mode and checks shapes and the loss."""
        # Build config.
        feature_spec = {
            "time_feature_1": {
                "length": 10,
                "is_time_series": True,
            },
            "time_feature_2": {
                "length": 10,
                "is_time_series": True,
            },
            "aux_feature_1": {
                "length": 1,
                "is_time_series": False,
            },
        }
        config = configurations.base()
        config["inputs"]["features"] = feature_spec
        config = configdict.ConfigDict(config)
        config.hparams.output_dim = 1

        # Build model. EVAL (not TRAIN) is required so this test actually
        # covers evaluation mode.
        features = input_ops.build_feature_placeholders(config.inputs.features)
        labels = input_ops.build_labels_placeholder()
        model = astro_model.AstroModel(features, labels, config.hparams,
                                       tf.estimator.ModeKeys.EVAL)
        model.build()

        # Validate Tensor shapes.
        self.assertShapeEquals((None, 21), model.pre_logits_concat)
        self.assertShapeEquals((None, 1), model.logits)
        self.assertShapeEquals((None, 1), model.predictions)
        self.assertShapeEquals((None, ), model.batch_losses)
        self.assertShapeEquals((), model.total_loss)

        # Execute the TensorFlow graph.
        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.test_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            step = sess.run(model.global_step)
            self.assertEqual(0, step)

            # Fetch total loss.
            features = testing.fake_features(feature_spec, batch_size=16)
            labels = testing.fake_labels(config.hparams.output_dim,
                                         batch_size=16)
            feed_dict = input_ops.prepare_feed_dict(model, features, labels)
            total_loss = sess.run(model.total_loss, feed_dict=feed_dict)
            self.assertShapeEquals((), total_loss)
  def testZeroFeaturesRaisesError(self):
    # A config with an empty feature set.
    base_config = configurations.base()
    base_config["inputs"]["features"] = {}
    config = configdict.ConfigDict(base_config)

    # Construction itself is expected to succeed.
    model = astro_model.AstroModel(
        input_ops.build_feature_placeholders(config.inputs.features),
        input_ops.build_labels_placeholder(), config.hparams,
        tf.estimator.ModeKeys.TRAIN)

    # build() raises ValueError because at least one feature is required.
    with self.assertRaises(ValueError):
      model.build()
Beispiel #16
0
def main(_):
    """Trains a model, alternating with evaluation when --eval_files is set.

    Raises:
      ValueError: If not exactly one of --config_name / --config_json is set.
    """
    model_class = models.get_model_class(FLAGS.model)

    # Look up the model configuration. Raise explicitly: an `assert` would be
    # stripped under `python -O`, silently skipping the check.
    if (FLAGS.config_name is None) == (FLAGS.config_json is None):
        raise ValueError(
            "Exactly one of --config_name or --config_json is required.")
    config = (models.get_model_config(FLAGS.model, FLAGS.config_name) if
              FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

    config = configdict.ConfigDict(config)
    config_util.log_and_save_config(config, FLAGS.model_dir)

    # Create the estimator; keep_checkpoint_max=1 retains only the latest
    # checkpoint.
    run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
    estimator = estimator_util.create_estimator(model_class, config.hparams,
                                                run_config, FLAGS.model_dir)

    # Create an input function that reads the training dataset. It makes a
    # single pass per call when alternating with evaluation, and repeats
    # indefinitely otherwise.
    train_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.train_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.TRAIN,
        shuffle_values_buffer=FLAGS.shuffle_buffer_size,
        repeat=1 if FLAGS.eval_files else None)

    if not FLAGS.eval_files:
        estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
    else:
        eval_input_fn = estimator_util.create_input_fn(
            file_pattern=FLAGS.eval_files,
            input_config=config.inputs,
            mode=tf.estimator.ModeKeys.EVAL)
        eval_args = {
            "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
        }

        for _ in estimator_runner.continuous_train_and_eval(
                estimator=estimator,
                train_input_fn=train_input_fn,
                eval_args=eval_args,
                train_steps=FLAGS.train_steps):
            # continuous_train_and_eval() yields evaluation metrics after each
            # training epoch. We don't do anything here.
            pass
def main(_):
    """Scores every TCE in --input_tce_csv_file with a trained model.

    Each prediction is printed to stdout and appended as a
    "<tic_id> <prediction>" line to --output_file.

    Raises:
      ValueError: If not exactly one of --config_name / --config_json is set.
    """
    model_class = models.get_model_class(FLAGS.model)

    # Look up the model configuration. Raise explicitly: an `assert` would be
    # stripped under `python -O`, silently skipping the check.
    if (FLAGS.config_name is None) == (FLAGS.config_json is None):
        raise ValueError(
            "Exactly one of --config_name or --config_json is required.")
    config = (models.get_model_config(FLAGS.model, FLAGS.config_name) if
              FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
    config = configdict.ConfigDict(config)

    # Create the estimator.
    estimator = estimator_util.create_estimator(model_class,
                                                config.hparams,
                                                model_dir=FLAGS.model_dir)

    # Read and process the input features.
    tce_table = pd.read_csv(FLAGS.input_tce_csv_file,
                            header=0,
                            usecols=[0, 1, 2, 3, 7, 8, 9, 10, 11, 12, 13, 18],
                            dtype={
                                'Sectors': int,
                                'camera': int,
                                'ccd': int
                            })

    # Open the output file once (append mode, as before) and close it
    # deterministically. Previously a new, never-closed handle was opened
    # for every prediction.
    with open(FLAGS.output_file, 'a') as output_file:
        for _, tce in tce_table.iterrows():
            features = _process_tce(config.inputs.features, tce)

            # input_fn is fully consumed by estimator.predict() within this
            # iteration, so closing over the loop variable `features` is safe.
            def input_fn():
                return {
                    "time_series_features":
                    tf.compat.v1.estimator.inputs.numpy_input_fn(
                        features, batch_size=1, shuffle=False,
                        queue_capacity=1)()
                }

            # Generate the predictions.
            for predictions in estimator.predict(input_fn):
                assert len(predictions) == 1
                print(tce.tic_id, "Prediction:", predictions[0])

                # Flush per line so partial results survive an interrupted run.
                print(str(tce.tic_id) + ' ' + str(predictions[0]),
                      file=output_file, flush=True)
 def setUp(self):
     super(ConfigDictTest, self).setUp()
     # Fixture covering scalar types plus one- and two-level nesting.
     nested = {"int": 3}
     double_nested = {
         "a": {"int": 3},
         "b": {"float": 4.0},
     }
     self._config = configdict.ConfigDict({
         "int": 1,
         "float": 2.0,
         "bool": True,
         "str": "hello",
         "nested": nested,
         "double_nested": double_nested,
     })
Beispiel #19
0
def create_config(model_name="AstroCNNModel", config_name="local_global",
                  config_json="config_json"):
    """Builds a ConfigDict for a model.

    Args:
      model_name: Name of the model class whose configurations are searched.
      config_name: Name of a configuration-builder function of that model. If
        falsy, the configuration is parsed from `config_json` instead.
      config_json: Argument passed to config_util.parse_json() when
        `config_name` is falsy (presumably a JSON string or a path to a JSON
        file). The default reproduces the previously hard-coded literal
        "config_json", so existing callers are unaffected.

    Returns:
      A configdict.ConfigDict holding the configuration.
    """
    if config_name:
        config = models.get_model_config(model_name, config_name)
    else:
        config = config_util.parse_json(config_json)
    return configdict.ConfigDict(config)
Beispiel #20
0
    def testOneTimeSeriesFeature(self):
        # One length-20 time series feature feeding a 2-block CNN.
        feature_spec = {
            "time_feature_1": {"length": 20, "is_time_series": True},
        }
        cnn_spec = {
            "cnn_num_blocks": 2,
            "cnn_block_size": 2,
            "cnn_initial_num_filters": 4,
            "cnn_block_filter_factor": 1.5,
            "cnn_kernel_size": 3,
            "convolution_padding": "same",
            "pool_size": 2,
            "pool_strides": 2,
        }
        config = configurations.base()
        config["inputs"]["features"] = feature_spec
        config["hparams"]["time_series_hidden"] = {"time_feature_1": cnn_spec}
        config = configdict.ConfigDict(config)

        # Build the CNN model in training mode.
        model = astro_cnn_model.AstroCNNModel(
            input_ops.build_feature_placeholders(config.inputs.features),
            input_ops.build_labels_placeholder(), config.hparams,
            tf.estimator.ModeKeys.TRAIN)
        model.build()

        # Expected conv kernel shapes: the first dimension matches
        # cnn_kernel_size, and block 2 grows 4 filters to 6 (factor 1.5).
        expected_kernels = [
            ("block_1/conv_1", (3, 1, 4)),
            ("block_1/conv_2", (3, 4, 4)),
            ("block_2/conv_1", (3, 4, 6)),
            ("block_2/conv_2", (3, 6, 6)),
        ]
        for layer, shape in expected_kernels:
            kernel = testing.get_variable_by_name(
                "time_feature_1_hidden/%s/kernel" % layer)
            self.assertShapeEquals(shape, kernel)

        hidden = model.time_series_hidden_layers
        self.assertItemsEqual(["time_feature_1"], hidden.keys())
        self.assertShapeEquals((None, 30), hidden["time_feature_1"])
        self.assertEqual(len(model.aux_hidden_layers), 0)
        self.assertIs(hidden["time_feature_1"], model.pre_logits_concat)

        # Initialize variables and run a forward pass on a fake batch of 16.
        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.test_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            feed_dict = input_ops.prepare_feed_dict(
                model,
                testing.fake_features(feature_spec, batch_size=16),
                testing.fake_labels(config.hparams.output_dim, batch_size=16))
            predictions = sess.run(model.predictions, feed_dict=feed_dict)
            self.assertShapeEquals((16, 1), predictions)
Beispiel #21
0
def main(_):
    """Averages per-TCE predictions across models and reports threshold metrics.

    A TCE counts as positive when its disposition is 'PC' or 'EB'. For each
    threshold this prints the true negative rate, then precision and recall
    and the number of PCs scored below it. Finally it saves
    (is_positive, average_prediction) rows to true_vs_pred_<suffix>.txt.

    Raises:
      ValueError: If not exactly one of --config_name / --config_json is set,
        or if no predictions were produced.
    """
    # Look up the model class.
    model_class = models.get_model_class(FLAGS.model)

    # Look up the model configuration.
    if (FLAGS.config_name is None) == (FLAGS.config_json is None):
        raise ValueError(
            "Exactly one of config_name or config_json is required.")
    config = (models.get_model_config(FLAGS.model, FLAGS.config_name) if
              FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
    config = configdict.ConfigDict(config)

    if FLAGS.average:
        model_dirs = glob.glob(FLAGS.model_dir + '/*')
    else:
        model_dirs = [FLAGS.model_dir]

    y_pred = defaultdict(list)
    true_disp = defaultdict(list)
    for model_dir in model_dirs:
        # predict() appends a new prediction to y_pred for each TCE.
        y_pred, true_disp = predict(model_dir, config, model_class, y_pred,
                                    true_disp)

    if not y_pred:
        # Otherwise the column slicing below fails with an opaque IndexError.
        raise ValueError("No predictions were produced from: %s" %
                         FLAGS.model_dir)

    # Average score below which a TCE is labelled "junk" for plotting.
    junk_threshold = 0.1
    y_pred_average = []
    is_pc = []
    num_tces = len(y_pred)

    for cnt, tce in enumerate(y_pred):
        # Log progress every 100 TCEs; the optional plotting below is slow.
        if cnt % 100 == 0:
            tf.logging.info("Averaging %d of %d", cnt, num_tces)

        disposition = true_disp[tce]
        is_positive = disposition in ['PC', 'EB']
        is_pc.append(disposition == 'PC')
        average_pred = np.mean(y_pred[tce])
        label = "PC/EB" if average_pred >= junk_threshold else "junk"

        # Plot known PCs whose average score falls below junk_threshold.
        if (FLAGS.plot and disposition == 'PC' and
                average_pred < junk_threshold):
            kepid, sector = tce.split('-')
            ex = find_tce(int(kepid), int(sector))
            plot_tce(ex.features.feature["tic_id"].int64_list.value[0],
                     ex.features.feature['Sectors'].int64_list.value[0], label,
                     average_pred)

        y_pred_average.append([is_positive, average_pred])

    results = np.array(y_pred_average)
    y_true = results[:, 0]
    y_pred = results[:, 1]
    is_pc = np.array(is_pc)

    def _ratio(numerator, denominator):
        # Report an undefined metric as nan instead of raising
        # ZeroDivisionError (e.g. when no TCE scores above a threshold).
        return float(numerator) / denominator if denominator else float('nan')

    positives = y_true == 1
    negatives = y_true == 0
    for t in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]:
        above = y_pred >= t
        below = y_pred < t
        tp = int(np.count_nonzero(positives & above))
        fp = int(np.count_nonzero(negatives & above))
        fn = int(np.count_nonzero(positives & below))
        tn = int(np.count_nonzero(negatives & below))
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)
        # 1 - false positive rate, i.e. the true negative rate.
        print(1 - _ratio(fp, tn + fp))

        num_pc = int(np.count_nonzero(is_pc & below))

        print(
            "Threshold %s: precision=%s, recall=%s. Number of PCs in FNs = %s"
            % (t, precision, recall, num_pc))

    np.savetxt('true_vs_pred_' + FLAGS.suffix + '.txt',
               results,
               fmt=['%f', '%4.3f'])
    pc_count = int(np.count_nonzero(is_pc))
    print('Total %s PCs' % pc_count)
Beispiel #22
0
    def testBuildFeaturePlaceholders(self):
        # Table-driven: each case pairs a feature config with the expected
        # placeholder shapes, checked in order.
        def spec(length, is_time_series):
            return {"length": length, "is_time_series": is_time_series}

        cases = [
            # One time series feature.
            ({"time_feature_1": spec(14, True)},
             {"time_series_features": {"time_feature_1": [None, 14]},
              "aux_features": {}}),
            # Two time series features.
            ({"time_feature_1": spec(14, True),
              "time_feature_2": spec(5, True)},
             {"time_series_features": {"time_feature_1": [None, 14],
                                       "time_feature_2": [None, 5]},
              "aux_features": {}}),
            # One aux feature.
            ({"time_feature_1": spec(14, True),
              "aux_feature_1": spec(1, False)},
             {"time_series_features": {"time_feature_1": [None, 14]},
              "aux_features": {"aux_feature_1": [None, 1]}}),
            # Two aux features.
            ({"time_feature_1": spec(14, True),
              "aux_feature_1": spec(1, False),
              "aux_feature_2": spec(6, False)},
             {"time_series_features": {"time_feature_1": [None, 14]},
              "aux_features": {"aux_feature_1": [None, 1],
                               "aux_feature_2": [None, 6]}}),
        ]
        for feature_config, expected_shapes in cases:
            features = input_ops.build_feature_placeholders(
                configdict.ConfigDict(feature_config))
            self.assertFeatureShapesEqual(expected_shapes, features)
    def test_build_model_categorical(self):
        time_series_length = 9
        input_num_features = 8
        context_num_features = 7

        def _placeholder(num_features, name):
            return tf.placeholder(
                dtype=tf.float32,
                shape=[None, time_series_length, num_features],
                name=name)

        input_placeholder = _placeholder(input_num_features, "input")
        context_placeholder = _placeholder(context_num_features, "context")
        features = {
            "autoregressive_input": input_placeholder,
            "conditioning_stack": context_placeholder
        }
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": {
                "type": "categorical",
                "num_classes": 256,
                "min_quantization_value": -1,
                "max_quantization_value": 1
            }
        })

        model = astrowavenet_model.AstroWaveNet(features, hparams,
                                                tf.estimator.ModeKeys.TRAIN)
        model.build()

        # Expected output width of the dist_params 1x1 convolution.
        num_outputs = (hparams.output_distribution.num_classes *
                       input_num_features)
        trainable = {v.op.name: v for v in tf.trainable_variables()}
        self.assertShapeEquals((1, hparams.skip_output_dim, num_outputs),
                               trainable["dist_params/conv1x1/kernel"])
        self.assertShapeEquals((num_outputs, ),
                               trainable["dist_params/conv1x1/bias"])

        # Verify the model runs and outputs losses of the correct shape.
        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            batch_size = 11
            random_input = np.random.random(
                (batch_size, time_series_length, input_num_features))
            random_context = np.random.random(
                (batch_size, time_series_length, context_num_features))
            batch_losses, per_example_loss, total_loss = sess.run(
                [model.batch_losses, model.per_example_loss, model.total_loss],
                feed_dict={
                    input_placeholder: random_input,
                    context_placeholder: random_context,
                })
            self.assertShapeEquals(
                (batch_size, time_series_length, input_num_features),
                batch_losses)
            self.assertShapeEquals((batch_size, ), per_example_loss)
            self.assertShapeEquals((), total_loss)
    def test_output_categorical(self):
        """Checks targets and losses of the quantized categorical output.

        Logits are fed directly into `model.dist_params["logits"]`, so the
        conditioning stack is never evaluated. Inputs are clipped to
        [min_quantization_value, max_quantization_value] and mapped to one of
        `num_classes` buckets to form the autoregressive target.
        """
        time_series_length = 3
        input_num_features = 1
        context_num_features = 7
        num_classes = 4  # For quantized categorical output predictions.

        autoregressive_input = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        conditioning_stack = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")

        output_distribution = {
            "type": "categorical",
            "min_scale": 0,
            "num_classes": num_classes,
            "min_quantization_value": 0,
            "max_quantization_value": 1,
        }
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": output_distribution,
        })

        model = astrowavenet_model.AstroWaveNet(
            {
                "autoregressive_input": autoregressive_input,
                "conditioning_stack": conditioning_stack,
            },
            hparams,
            tf.estimator.ModeKeys.TRAIN)
        model.build()

        self.assertItemsEqual(["logits"], model.dist_params.keys())
        logits = model.dist_params["logits"]
        self.assertShapeEquals(
            (None, time_series_length, input_num_features, num_classes),
            logits)

        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            # Context is not needed since we explicitly feed the dist params.
            fetched = sess.run(
                {
                    "target": model.autoregressive_target,
                    "batch_losses": model.batch_losses,
                    "per_example_loss": model.per_example_loss,
                    "num_examples": model.num_nonzero_weight_examples,
                    "total_loss": model.total_loss,
                },
                feed_dict={
                    autoregressive_input: [
                        [[0], [0], [0]],  # min_quantization_value
                        [[0.2], [0.2], [0.2]],  # Within bucket.
                        [[0.25], [0.25], [0.25]],  # On bucket boundary.
                        [[0.5], [0.5], [0.5]],  # On bucket boundary.
                        [[0.8], [0.8], [0.8]],  # Within bucket.
                        [[1], [1], [1]],  # max_quantization_value
                        [[-0.1], [1.5], [200]],  # Outside range: clipped.
                    ],
                    logits: [
                        [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
                        [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
                        [[[0, 1, 0, 0]], [[1, 0, 0, 0]], [[0, 0, 1, 0]]],
                        [[[0, 0, 1, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
                        [[[0, 0, 0, 1]], [[1, 0, 0, 0]], [[1, 0, 0, 0]]],
                        [[[0, 0, 0, 1]], [[0, 1, 0, 0]], [[0, 0, 1, 0]]],
                        [[[1, 0, 0, 0]], [[0, 0, 1, 0]], [[0, 1, 0, 0]]],
                    ],
                })

            np.testing.assert_array_almost_equal([
                [[0], [0], [0]],
                [[0], [0], [0]],
                [[1], [1], [1]],
                [[2], [2], [2]],
                [[3], [3], [3]],
                [[3], [3], [3]],
                [[0], [3], [3]],
            ], fetched["target"])

            # In every example the first step's one-hot logit lands on the
            # target class and the other two steps miss. The expected values
            # equal softmax cross-entropy of a one-hot logit vector:
            # log(e + 3) - 1 ~= 0.74366838 on a hit, log(e + 3) ~= 1.74366838
            # on a miss.
            step_losses = [[0.74366838], [1.74366838], [1.74366838]]
            np.testing.assert_array_almost_equal(
                np.tile(step_losses, (7, 1, 1)), fetched["batch_losses"])

            # Each expected per-example loss is the mean of its three steps.
            np.testing.assert_array_almost_equal([1.41033504] * 7,
                                                 fetched["per_example_loss"])
            np.testing.assert_almost_equal(7, fetched["num_examples"])
            np.testing.assert_almost_equal(1.41033504, fetched["total_loss"])
    def test_build_model(self):
        """Builds a normal-output model and checks its variables and losses.

        Checks the shapes of the trainable variables of the preprocess layer,
        two residual blocks and the output layers; compares the trainable
        parameter count against a closed-form count; then runs one random
        batch and checks the shapes of the loss tensors.
        """
        time_series_length = 9
        input_num_features = 8
        context_num_features = 7

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": {
                "type": "normal",
                "min_scale": 0.001,
            }
        })

        model = astrowavenet_model.AstroWaveNet(
            {
                "autoregressive_input": input_placeholder,
                "conditioning_stack": context_placeholder
            },
            hparams,
            tf.estimator.ModeKeys.TRAIN)
        model.build()

        # Verify variable shapes in two residual blocks.
        def conv_shapes(scope, kernel_shape, bias_shape):
            """Returns (name, shape) pairs for a conv layer's kernel and bias."""
            return [(scope + "/kernel", kernel_shape),
                    (scope + "/bias", bias_shape)]

        expected_shapes = conv_shapes("preprocess/causal_conv", (5, 8, 3),
                                      (3, ))
        for block in ("block_0/dilation_1", "block_1/dilation_4"):
            for branch in ("filter", "gate"):
                branch_scope = block + "/" + branch
                expected_shapes += conv_shapes(branch_scope + "/causal_conv",
                                               (2, 3, 3), (3, ))
                expected_shapes += conv_shapes(branch_scope + "/conv1x1",
                                               (1, 7, 3), (3, ))
            expected_shapes += conv_shapes(block + "/residual/conv1x1",
                                           (1, 3, 3), (3, ))
            expected_shapes += conv_shapes(block + "/skip/conv1x1", (1, 3, 6),
                                           (6, ))
        expected_shapes += conv_shapes("postprocess/conv1x1", (1, 6, 6), (6, ))
        expected_shapes += conv_shapes("dist_params/conv1x1", (1, 6, 16),
                                       (16, ))

        trainable = {v.op.name: v for v in tf.trainable_variables()}
        for var_name, shape in expected_shapes:
            self.assertShapeEquals(shape, trainable[var_name])

        # Verify total number of trainable parameters.
        def conv_params(width, in_dim, out_dim):
            """Kernel plus bias parameter count of a (width, in, out) conv."""
            return width * in_dim * out_dim + out_dim

        hidden = hparams.preprocess_output_size
        skip = hparams.skip_output_dim
        # Filter and gate branches each have a dilated causal conv on the
        # hidden state and a 1x1 conv on the context.
        num_gated_params = 2 * (
            conv_params(hparams.dilation_kernel_width, hidden, hidden) +
            conv_params(1, context_num_features, hidden))
        num_block_params = (
            num_gated_params + conv_params(1, hidden, hidden) +
            conv_params(1, hidden, skip)) * len(
                hparams.dilation_rates) * hparams.num_residual_blocks
        total_params = (
            conv_params(hparams.preprocess_kernel_width, input_num_features,
                        hidden) +
            num_block_params +
            conv_params(1, skip, skip) +
            conv_params(1, skip, 2 * input_num_features))

        total_retrieved_params = sum(
            np.prod(v.shape) for v in tf.trainable_variables())
        self.assertEqual(total_params, total_retrieved_params)

        # Verify model runs and outputs losses of correct shape.
        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            batch_size = 11
            random_input = np.random.random(
                (batch_size, time_series_length, input_num_features))
            random_context = np.random.random(
                (batch_size, time_series_length, context_num_features))
            batch_losses, per_example_loss, total_loss = sess.run(
                [model.batch_losses, model.per_example_loss, model.total_loss],
                feed_dict={
                    input_placeholder: random_input,
                    context_placeholder: random_context,
                })
            self.assertShapeEquals(
                (batch_size, time_series_length, input_num_features),
                batch_losses)
            self.assertShapeEquals((batch_size, ), per_example_loss)
            self.assertShapeEquals((), total_loss)
    def test_output_weighted(self):
        """Checks that per-element weights mask the normal-output losses.

        Distribution parameters are fed directly, so the conditioning stack
        is never evaluated. The expected values below correspond to:
          * batch_losses: the Gaussian negative log-likelihood
            0.5*log(2*pi) + log(scale) + 0.5*((x - loc) / scale)**2,
            multiplied by the element's weight;
          * per_example_loss: sum of weighted losses / sum of weights
            (0 for an example whose weights are all zero);
          * total_loss: mean of per_example_loss over the examples with
            nonzero weight (num_nonzero_weight_examples).
        """
        time_series_length = 6
        input_num_features = 2
        context_num_features = 7

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        # Previously also named "input" (copy-paste); TF silently uniquified
        # it, which made the graph confusing to inspect.
        weights_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="weights")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        features = {
            "autoregressive_input": input_placeholder,
            "weights": weights_placeholder,
            "conditioning_stack": context_placeholder
        }
        mode = tf.estimator.ModeKeys.TRAIN
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": {
                "type": "normal",
                "min_scale": 0,
            }
        })

        model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
        model.build()

        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            step = sess.run(model.global_step)
            self.assertEqual(0, step)

            feed_dict = {
                input_placeholder: [
                    [[1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9]],
                    [[2, 8], [2, 8], [2, 8], [2, 8], [2, 8], [2, 8]],
                    [[3, 7], [3, 7], [3, 7], [3, 7], [3, 7], [3, 7]],
                ],
                # Example 0 is fully weighted, example 1 partially, and
                # example 2 has all-zero weights.
                weights_placeholder: [
                    [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]],
                    [[1, 0], [1, 1], [1, 1], [0, 1], [0, 1], [0, 0]],
                    [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],
                ],
                # Context is not needed since we explicitly feed the dist params.
                model.dist_params["loc"]: [
                    [[1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8]],
                    [[2, 9], [2, 9], [2, 9], [2, 9], [2, 9], [2, 9]],
                    [[3, 6], [3, 6], [3, 6], [3, 6], [3, 6], [3, 6]],
                ],
                model.dist_params["scale"]: [
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                ],
            }
            batch_losses, per_example_loss, num_examples, total_loss = sess.run(
                [
                    model.batch_losses, model.per_example_loss,
                    model.num_nonzero_weight_examples, model.total_loss
                ],
                feed_dict=feed_dict)
            np.testing.assert_array_almost_equal(
                [[[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
                  [0.22579135, 2.22579135], [0.91893853, 1.41893853],
                  [1.61208571, 1.73708571], [2.52837645, 2.54837645]],
                 [[-1.38364656, 0], [-0.69049938, 11.80950062],
                  [0.22579135, 2.22579135], [0, 1.41893853], [0, 1.73708571],
                  [0, 0]], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]],
                batch_losses)
            np.testing.assert_array_almost_equal([5.96392435, 2.19185166, 0],
                                                 per_example_loss)
            np.testing.assert_almost_equal(2, num_examples)
            np.testing.assert_almost_equal(4.07788801, total_loss)
    def test_causality(self):
        """Checks which output time steps each input and context step reaches.

        Input elements are used to predict the next timestamp, while context
        elements are used to predict the current timestamp. The test feeds
        single nonzero entries and checks where the network output becomes
        nonzero.
        """
        time_series_length = 7
        input_num_features = 1
        context_num_features = 1

        def one_hot_rows(hot_steps):
            """Stacks one [time_series_length, 1] series per entry of hot_steps.

            Each series is zero except for a 1 at the given step; an entry of
            None leaves that series entirely zero.
            """
            rows = np.zeros((len(hot_steps), time_series_length, 1),
                            dtype=np.float32)
            for row, step in enumerate(hot_steps):
                if step is not None:
                    rows[row, step, 0] = 1
            return rows

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 1,
            "skip_output_dim": 1,
            "preprocess_output_size": 1,
            "preprocess_kernel_width": 1,
            "num_residual_blocks": 1,
            "dilation_rates": [1],
            "output_distribution": {
                "type": "normal",
                "min_scale": 0.001,
            }
        })

        model = astrowavenet_model.AstroWaveNet(
            {
                "autoregressive_input": input_placeholder,
                "conditioning_stack": context_placeholder
            },
            hparams,
            tf.estimator.ModeKeys.TRAIN)
        model.build()

        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            # Examples 1-3 each set a single input step; examples 4-6 each set
            # a single context step; example 0 is all zeros.
            network_output = sess.run(
                model.network_output,
                feed_dict={
                    input_placeholder:
                        one_hot_rows([None, 0, 3, 6, None, None, None]),
                    context_placeholder:
                        one_hot_rows([None, None, None, None, 0, 3, 6]),
                })

            # Input at step t shows up at step t + 1 (an input at the last
            # step reaches nothing); context at step t shows up at step t.
            expected_nonzero = one_hot_rows([None, 1, 4, None, 0, 3, 6])
            np.testing.assert_array_equal(
                expected_nonzero, np.greater(np.abs(network_output), 0))
Beispiel #28
0
    def testOneTimeSeriesFeature(self):
        """Builds an AstroFCModel with a single time-series feature.

        Checks the hidden-layer variable shapes, that the only hidden layer is
        used directly as `pre_logits_concat`, and that a fake batch of 16
        examples yields predictions of shape (16, 1).
        """
        # Build config.
        feature_spec = {
            "time_feature_1": {
                "length": 14,
                "is_time_series": True,
            }
        }
        hidden_spec = {
            "time_feature_1": {
                "num_local_layers": 2,
                "local_layer_size": 20,
                "translation_delta": 2,
                "pooling_type": "max",
                "dropout_rate": 0.5,
            }
        }
        config = configurations.base()
        config["inputs"]["features"] = feature_spec
        config["hparams"]["time_series_hidden"] = hidden_spec
        config = configdict.ConfigDict(config)

        # Build model.
        features = input_ops.build_feature_placeholders(config.inputs.features)
        labels = input_ops.build_labels_placeholder()
        model = astro_fc_model.AstroFCModel(features, labels, config.hparams,
                                            tf.estimator.ModeKeys.TRAIN)
        model.build()

        # Validate Tensor shapes.
        conv = testing.get_variable_by_name(
            "time_feature_1_hidden/conv1d/kernel")
        self.assertShapeEquals((10, 1, 20), conv)

        fc_1 = testing.get_variable_by_name(
            "time_feature_1_hidden/fully_connected_1/weights")
        self.assertShapeEquals((20, 20), fc_1)

        self.assertItemsEqual(["time_feature_1"],
                              model.time_series_hidden_layers.keys())
        self.assertShapeEquals(
            (None, 20), model.time_series_hidden_layers["time_feature_1"])
        self.assertEqual(len(model.aux_hidden_layers), 0)
        # With a single hidden layer, no concatenation is expected.
        self.assertIs(model.time_series_hidden_layers["time_feature_1"],
                      model.pre_logits_concat)

        # Execute the TensorFlow graph. `test_session()` is deprecated in
        # favor of `cached_session()`, which it delegates to when no graph is
        # passed.
        scaffold = tf.compat.v1.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            step = sess.run(model.global_step)
            self.assertEqual(0, step)

            # Fetch predictions for a fake batch (distinct names so the
            # placeholders built above are not shadowed).
            batch_features = testing.fake_features(feature_spec, batch_size=16)
            batch_labels = testing.fake_labels(config.hparams.output_dim,
                                               batch_size=16)
            feed_dict = input_ops.prepare_feed_dict(model, batch_features,
                                                    batch_labels)
            predictions = sess.run(model.predictions, feed_dict=feed_dict)
            self.assertShapeEquals((16, 1), predictions)
Beispiel #29
0
def main(argv):
    """Trains and/or evaluates an AstroWaveNet estimator as set by FLAGS.

    FLAGS.schedule selects what runs:
      "train": trains for FLAGS.train_steps.
      "train_and_eval": alternates training with evaluation every
        FLAGS.local_eval_frequency.
      "continuous_eval": evaluates each new checkpoint.

    FLAGS.config_overrides is a JSON object. Its optional "hparams" entry is
    merged into the named configuration's hparams; its optional "dataset"
    entry is passed to the input-fn and eval-args builders.

    Args:
      argv: Unused command-line arguments.

    Raises:
      ValueError: If FLAGS.schedule is not a recognized schedule, or if
        FLAGS.config_overrides is not a JSON object or has a key other than
        "dataset" or "hparams".
    """
    del argv  # Unused.

    # Validate the schedule before building anything or writing any files.
    # This used to be checked with `assert` after the estimator was built;
    # asserts are stripped under `python -O`.
    valid_schedules = ("train", "train_and_eval", "continuous_eval")
    if FLAGS.schedule not in valid_schedules:
        raise ValueError("Unrecognized schedule: {}. Expected one of: {}".format(
            FLAGS.schedule, ", ".join(valid_schedules)))

    config = configdict.ConfigDict(configurations.get_config(
        FLAGS.config_name))
    config_overrides = json.loads(FLAGS.config_overrides)
    # A JSON list/null/scalar would otherwise fail later with an unhelpful
    # TypeError/AttributeError.
    if not isinstance(config_overrides, dict):
        raise ValueError("config_overrides must be a JSON object, got: {}".format(
            FLAGS.config_overrides))
    for key in config_overrides:
        if key not in ["dataset", "hparams"]:
            raise ValueError("Unrecognized config override: {}".format(key))
    config.hparams.update(config_overrides.get("hparams", {}))
    dataset_overrides = config_overrides.get("dataset")

    # Log configs.
    configs_json = [
        ("config_overrides", config_util.to_json(config_overrides)),
        ("config", config_util.to_json(config)),
    ]
    for config_name, config_json in configs_json:
        tf.logging.info("%s: %s", config_name, config_json)

    # Create the estimator.
    run_config = _create_run_config()
    estimator = estimator_util.create_estimator(
        astrowavenet_model.AstroWaveNet, config.hparams, run_config,
        FLAGS.model_dir, FLAGS.eval_batch_size)

    if FLAGS.schedule in ["train", "train_and_eval"]:
        # Save configs.
        tf.gfile.MakeDirs(FLAGS.model_dir)
        for config_name, config_json in configs_json:
            filename = os.path.join(FLAGS.model_dir,
                                    "{}.json".format(config_name))
            with tf.gfile.Open(filename, "w") as f:
                f.write(config_json)

        train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
                                          dataset_overrides)

        train_hooks = []
        if FLAGS.schedule == "train":
            estimator.train(train_input_fn,
                            hooks=train_hooks,
                            max_steps=FLAGS.train_steps)
        else:
            # FLAGS.schedule == "train_and_eval" (validated above).
            eval_args = _create_eval_args(dataset_overrides)
            for _ in estimator_runner.continuous_train_and_eval(
                    estimator=estimator,
                    train_input_fn=train_input_fn,
                    eval_args=eval_args,
                    local_eval_frequency=FLAGS.local_eval_frequency,
                    train_hooks=train_hooks,
                    train_steps=FLAGS.train_steps):
                # continuous_train_and_eval() yields evaluation metrics after each
                # FLAGS.local_eval_frequency. It also saves and logs them, so we don't
                # do anything here.
                pass

    else:
        # FLAGS.schedule == "continuous_eval" (validated above).
        eval_args = _create_eval_args(dataset_overrides)
        for _ in estimator_runner.continuous_eval(
                estimator=estimator,
                eval_args=eval_args,
                train_steps=FLAGS.train_steps):
            # continuous_eval() yields evaluation metrics after each
            # checkpoint. It also saves and logs them, so we don't do anything here.
            pass