Ejemplo n.º 1
0
  def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
               config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
    """Create a time series Estimator around `model`.

    Args:
      model: The `TimeSeriesModel` instance to train, evaluate and predict
          with.
      state_manager: Optional state manager. When omitted, `ARModel`
          instances get a `FilteringOnlyStateManager` and every other model
          gets a `PassthroughStateManager`.
      optimizer: Optional `tf.train.Optimizer` used for training. When
          omitted, an `AdamOptimizer` with step size 0.02 is used.
      model_dir: Forwarded unchanged to the `Estimator` constructor.
      config: Forwarded unchanged to the `Estimator` constructor.
      head_type: Callable (normally `TimeSeriesRegressionHead` or a subclass)
          used to build the head; its `create_estimator_spec` method becomes
          the Estimator's `model_fn`.
    """
    statistics_generator = math_utils.InputStatisticsFromMiniBatch(
        dtype=model.dtype, num_features=model.num_features)
    if state_manager is None:
      # The default state manager depends on whether the model is an ARModel.
      state_manager = (
          state_management.FilteringOnlyStateManager()
          if isinstance(model, ar_model.ARModel)
          else state_management.PassthroughStateManager())
    optimizer = (train.AdamOptimizer(0.02) if optimizer is None
                 else optimizer)
    self._model = model
    head = head_type(
        model=model,
        state_manager=state_manager,
        optimizer=optimizer,
        input_statistics_generator=statistics_generator)
    super(TimeSeriesRegressor, self).__init__(
        model_fn=head.create_estimator_spec,
        model_dir=model_dir,
        config=config)
Ejemplo n.º 2
0
def parameter_recovery_dry_run(
    generate_fn, generative_model, seed,
    learning_rate=0.1,
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=None):
  """Test that a generative model can train on generated data.

  Only two training iterations are run and nothing is asserted about the
  recovered parameters; this checks that generation, training and evaluation
  wire together without errors.

  Args:
    generate_fn: A function taking a model and returning a
        `input_pipeline.TimeSeriesReader` object and a dictionary mapping
        parameters to their values. model.initialize_graph() will have been
        called on the model before it is passed to this function.
    generative_model: A timeseries.model.TimeSeriesModel instance to test.
    seed: Same as for TimeSeriesModel.unconditional_generate().
    learning_rate: Step size for optimization.
    train_input_fn_type: The type of `TimeSeriesInputFn` to use when training
        (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). If None, use
        `WholeDatasetInputFn`.
    train_state_manager: The state manager to use when training (likely
        `PassthroughStateManager` or `ChainingStateManager`). If None (the
        default), a new `PassthroughStateManager` is created for this call.
  """
  if train_input_fn_type is None:
    train_input_fn_type = input_pipeline.WholeDatasetInputFn
  if train_state_manager is None:
    # Build a fresh instance per call instead of sharing a single default
    # instance created once at function-definition time.
    train_state_manager = state_management.PassthroughStateManager()
  _train_on_generated_data(
      generate_fn=generate_fn, generative_model=generative_model,
      seed=seed, learning_rate=learning_rate,
      train_input_fn_type=train_input_fn_type,
      train_state_manager=train_state_manager,
      train_iterations=2)
Ejemplo n.º 3
0
 def test_metrics_consistent(self):
     """Tests that in-sample prediction metrics match standard metrics.

     Builds an EVAL-mode estimator spec for `_TickerModel` from a one-element
     feature batch (TIMES of shape (1, 1), VALUES of shape (1, 1, 1)) plus a
     "ticker" feature read from a local int64 counter. It runs every metric
     update op once, then checks that the loss mean, the "ticker" metric and
     the state-tuple metric all read zero.
     """
     g = tf.Graph()
     with g.as_default():
         features = {
             feature_keys.TrainEvalFeatures.TIMES:
             tf.zeros((1, 1)),
             feature_keys.TrainEvalFeatures.VALUES:
             tf.zeros((1, 1, 1)),
             # Local int64 counter starting at 0. count_up_to returns the
             # value before incrementing, so the first evaluation yields 0.
             "ticker":
             tf.reshape(
                 tf.cast(tf.compat.v1.Variable(
                     name="ticker",
                     initial_value=0,
                     dtype=tf.dtypes.int64,
                     collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES
                                  ]).count_up_to(10),
                         dtype=tf.dtypes.float32), (1, 1, 1))
         }
         model_fn = ts_head_lib.TimeSeriesRegressionHead(
             model=_TickerModel(),
             state_manager=state_management.PassthroughStateManager(),
             optimizer=tf.compat.v1.train.GradientDescentOptimizer(
                 0.001)).create_estimator_spec
         outputs = model_fn(features=features,
                            labels=None,
                            mode=estimator_lib.ModeKeys.EVAL)
         # Each eval_metric_ops entry is a (value, update_op) pair; collect
         # the update ops so they can all run in a single step.
         metric_update_ops = [
             metric[1] for metric in outputs.eval_metric_ops.values()
         ]
         # Standard streaming mean of the loss, used as the reference metric.
         loss_mean, loss_update = tf.compat.v1.metrics.mean(outputs.loss)
         metric_update_ops.append(loss_update)
         with self.cached_session() as sess:
             coordinator = tf.train.Coordinator()
             # Start any queue runners registered in the graph.
             tf.compat.v1.train.queue_runner.start_queue_runners(
                 sess, coord=coordinator)
             # The ticker counter lives in LOCAL_VARIABLES, as do the
             # streaming-metric accumulators.
             tf.compat.v1.initializers.local_variables().run()
             # Run every update op exactly once (one evaluation step).
             sess.run(metric_update_ops)
             # [0] selects each metric's value tensor. The extra [0] on the
             # state-tuple metric takes the first element of its nested value.
             loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
                 (loss_mean, outputs.eval_metric_ops["ticker"][0],
                  outputs.eval_metric_ops[
                      feature_keys.FilteringResults.STATE_TUPLE][0][0]))
             # The custom model_utils metrics for in-sample predictions should be in
             # sync with the Estimator's mean metric for model loss.
             self.assertAllClose(0., loss_evaled)
             self.assertAllClose((((0., ), ), ), metric_evaled)
             self.assertAllClose((((0., ), ), ), nested_metric_evaled)
             coordinator.request_stop()
             coordinator.join()
Ejemplo n.º 4
0
def _stub_model_fn():
  """Return an Estimator `model_fn` backed by a fresh `_StubModel`.

  The returned callable is the `create_estimator_spec` method of a
  `TimeSeriesRegressionHead` that uses a `PassthroughStateManager` and an
  Adam optimizer with step size 0.001.
  """
  stub_model = _StubModel()
  passthrough = state_management.PassthroughStateManager()
  adam_optimizer = tf.compat.v1.train.AdamOptimizer(0.001)
  head = ts_head_lib.TimeSeriesRegressionHead(
      model=stub_model,
      state_manager=passthrough,
      optimizer=adam_optimizer)
  return head.create_estimator_spec
Ejemplo n.º 5
0
def _train_on_generated_data(
    generate_fn, generative_model, train_iterations, seed,
    learning_rate=0.1, ignore_params_fn=lambda _: (),
    derived_param_test_fn=lambda _: (),
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=None):
  """The training portion of parameter recovery tests.

  Generates data (and the true parameter values) from `generative_model`.
  It evaluates the loss with the true parameters fed in, then trains a
  `TimeSeriesRegressor` on the generated data and evaluates its loss.

  Args:
    generate_fn: Callable taking the model (after `initialize_graph()`) and
        returning `(time_series_reader, true_parameters)`, where
        `true_parameters` maps parameter Tensors to their values.
    generative_model: The `TimeSeriesModel` to generate from and to train.
    train_iterations: Maximum number of training steps.
    seed: Passed to `random_seed.set_random_seed` and used as the
        Estimator's `tf_random_seed`.
    learning_rate: Adam step size used for training.
    ignore_params_fn: Callable mapping the model to parameters the caller
        will not check; evaluated in the true-parameter evaluation graph.
    derived_param_test_fn: Callable mapping the model to derived Tensors
        whose values under the true parameters are recorded.
    train_input_fn_type: `TimeSeriesInputFn` class used for training. If
        None, `WholeDatasetInputFn` is used.
    train_state_manager: State manager used for training. If None (the
        default), a new `PassthroughStateManager` is created for this call.

  Returns:
    A tuple `(ignore_params, true_parameters, true_transformed_params,
    trained_loss, true_param_loss, saving_hook, true_parameter_eval_graph)`:
      ignore_params: The value returned by `ignore_params_fn`.
      true_parameters: Dict mapping parameter tensor names to true values.
      true_transformed_params: Dict mapping each derived Tensor to its value
          with the true parameters fed in.
      trained_loss: The "loss" from evaluating the trained Estimator.
      true_param_loss: Loss evaluated with the true parameters fed in.
      saving_hook: The `_SavingTensorHook` that recorded parameter values
          during training.
      true_parameter_eval_graph: Graph containing the derived Tensors.
  """
  if train_input_fn_type is None:
    train_input_fn_type = input_pipeline.WholeDatasetInputFn
  if train_state_manager is None:
    # Fresh instance per call; a default created at definition time would be
    # shared between every call.
    train_state_manager = state_management.PassthroughStateManager()
  # NOTE(review): set_random_seed sets the seed of the *current* default
  # graph; the graphs created below do not inherit it. Confirm whether that
  # is intended before relying on it for reproducibility.
  random_seed.set_random_seed(seed)
  generate_graph = ops.Graph()
  with generate_graph.as_default():
    with session.Session(graph=generate_graph):
      generative_model.initialize_graph()
      time_series_reader, true_parameters = generate_fn(generative_model)
      # Key by tensor name so the values can be fed into other graphs and
      # matched against the names recorded by the saving hook.
      true_parameters = {
          tensor.name: value for tensor, value in true_parameters.items()}
  eval_input_fn = input_pipeline.WholeDatasetInputFn(time_series_reader)
  eval_state_manager = state_management.PassthroughStateManager()
  true_parameter_eval_graph = ops.Graph()
  with true_parameter_eval_graph.as_default():
    generative_model.initialize_graph()
    ignore_params = ignore_params_fn(generative_model)
    feature_dict, _ = eval_input_fn()
    eval_state_manager.initialize_graph(generative_model)
    feature_dict[TrainEvalFeatures.VALUES] = math_ops.cast(
        feature_dict[TrainEvalFeatures.VALUES], generative_model.dtype)
    model_outputs = eval_state_manager.define_loss(
        model=generative_model,
        features=feature_dict,
        mode=estimator_lib.ModeKeys.EVAL)
    with session.Session(graph=true_parameter_eval_graph) as sess:
      variables.global_variables_initializer().run()
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(sess, coord=coordinator)
      try:
        # Feeding the true values in place of the parameter tensors
        # evaluates everything as if the model had the generating parameters.
        true_param_loss = model_outputs.loss.eval(feed_dict=true_parameters)
        true_transformed_params = {
            param: param.eval(feed_dict=true_parameters)
            for param in derived_param_test_fn(generative_model)}
      finally:
        # Stop and join queue-runner threads even if evaluation fails.
        coordinator.request_stop()
        coordinator.join()

  # The interval is clamped to at least 1 because a non-positive
  # every_n_iter is rejected by LoggingTensorHook-style hooks (presumably
  # what _SavingTensorHook is). The value is unchanged for
  # train_iterations >= 2.
  saving_hook = _SavingTensorHook(
      tensors=true_parameters.keys(),
      every_n_iter=max(train_iterations - 1, 1))

  class _RunConfig(estimator_lib.RunConfig):
    # Overrides the config's seed so the Estimator's graphs use `seed`.

    @property
    def tf_random_seed(self):
      return seed

  estimator = estimators.TimeSeriesRegressor(
      model=generative_model,
      config=_RunConfig(),
      state_manager=train_state_manager,
      optimizer=adam.AdamOptimizer(learning_rate))
  train_input_fn = train_input_fn_type(time_series_reader=time_series_reader)
  trained_estimator = estimator.train(
      input_fn=train_input_fn,
      max_steps=train_iterations,
      hooks=[saving_hook])
  trained_loss = trained_estimator.evaluate(
      input_fn=eval_input_fn, steps=1)["loss"]
  logging.info("Final trained loss: %f", trained_loss)
  logging.info("True parameter loss: %f", true_param_loss)
  return (ignore_params, true_parameters, true_transformed_params,
          trained_loss, true_param_loss, saving_hook,
          true_parameter_eval_graph)
Ejemplo n.º 6
0
def test_parameter_recovery(
    generate_fn, generative_model, train_iterations, test_case, seed,
    learning_rate=0.1, rtol=0.2, atol=0.1, train_loss_tolerance_coeff=0.99,
    ignore_params_fn=lambda _: (),
    derived_param_test_fn=lambda _: (),
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=None):
  """Test that a generative model fits generated data.

  Assertions are made on `test_case`. The derived parameters are checked
  first, then the trained loss against the true-parameter loss, then each
  non-ignored parameter.

  Args:
    generate_fn: A function taking a model and returning a `TimeSeriesReader`
        object and dictionary mapping parameters to their
        values. model.initialize_graph() will have been called on the model
        before it is passed to this function.
    generative_model: A timeseries.model.TimeSeriesModel instance to test.
    train_iterations: Number of training steps.
    test_case: A tf.test.TestCase to run assertions on.
    seed: Same as for TimeSeriesModel.unconditional_generate().
    learning_rate: Step size for optimization.
    rtol: Relative tolerance for tests.
    atol: Absolute tolerance for tests.
    train_loss_tolerance_coeff: Trained loss times this value must be less
        than the loss evaluated using the generated parameters (divided by
        this value when the trained loss is not positive).
    ignore_params_fn: Function mapping from a Model to a list of parameters
        which are not tested for accurate recovery.
    derived_param_test_fn: Function returning a list of derived parameters
        (Tensors) which are checked for accurate recovery (comparing the value
        evaluated with trained parameters to the value under the true
        parameters).

        As an example, for VARMA, in addition to checking AR and MA parameters,
        this function can be used to also check lagged covariance. See
        varma_ssm.py for details.
    train_input_fn_type: The `TimeSeriesInputFn` type to use when training
        (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). If None, use
        `WholeDatasetInputFn`.
    train_state_manager: The state manager to use when training (likely
        `PassthroughStateManager` or `ChainingStateManager`). If None (the
        default), a new `PassthroughStateManager` is created for this call.
  """
  if train_input_fn_type is None:
    train_input_fn_type = input_pipeline.WholeDatasetInputFn
  if train_state_manager is None:
    # Fresh instance per call rather than one shared default instance.
    train_state_manager = state_management.PassthroughStateManager()
  (ignore_params, true_parameters, true_transformed_params,
   trained_loss, true_param_loss, saving_hook, true_parameter_eval_graph
  ) = _train_on_generated_data(
      generate_fn=generate_fn, generative_model=generative_model,
      train_iterations=train_iterations, seed=seed, learning_rate=learning_rate,
      ignore_params_fn=ignore_params_fn,
      derived_param_test_fn=derived_param_test_fn,
      train_input_fn_type=train_input_fn_type,
      train_state_manager=train_state_manager)

  # Learned values recorded by the saving hook, keyed by tensor name.
  trained_parameter_substitutions = {}
  for param, true_value in true_parameters.items():
    evaled_value = saving_hook.tensor_values[param]
    trained_parameter_substitutions[param] = evaled_value
    logging.info("True %s: %s, learned: %s",
                 param, true_value, evaled_value)

  # Re-evaluate each derived Tensor with the learned parameters fed in and
  # compare against its value under the true parameters.
  with session.Session(graph=true_parameter_eval_graph):
    for transformed_param, true_value in true_transformed_params.items():
      trained_value = transformed_param.eval(
          feed_dict=trained_parameter_substitutions)
      logging.info("True %s [transformed parameter]: %s, learned: %s",
                   transformed_param, true_value, trained_value)
      test_case.assertAllClose(true_value, trained_value,
                               rtol=rtol, atol=atol)

  # A set gives O(1) membership checks in the loop below.
  if ignore_params is None:
    ignored_names = set()
  else:
    ignored_names = {tensor.name for tensor in nest.flatten(ignore_params)}

  # With a coefficient below 1, both branches loosen the comparison: the
  # trained loss may slightly exceed the true-parameter loss.
  if trained_loss > 0:
    test_case.assertLess(trained_loss * train_loss_tolerance_coeff,
                         true_param_loss)
  else:
    test_case.assertLess(trained_loss / train_loss_tolerance_coeff,
                         true_param_loss)

  for param, true_value in true_parameters.items():
    if param in ignored_names:
      continue
    test_case.assertAllClose(true_value,
                             trained_parameter_substitutions[param],
                             rtol=rtol, atol=atol)