Ejemplo n.º 1
0
  def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
               config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
    """Initialize the Estimator.

    Args:
      model: The time series model to wrap (inheriting from TimeSeriesModel).
          Its `dtype` and `num_features` configure the input statistics
          generator handed to the head.
      state_manager: The state manager to use. When None, a
          FilteringOnlyStateManager is created if `model` is an
          `ar_model.ARModel`, and a PassthroughStateManager otherwise.
      optimizer: The optimization algorithm to use when training, inheriting
          from tf.train.Optimizer. Defaults to Adam with step size 0.02.
      model_dir: See `Estimator`.
      config: See `Estimator`.
      head_type: The kind of head to use for the model (inheriting from
          `TimeSeriesRegressionHead`). It is constructed with the model,
          state manager, optimizer and input statistics generator, and its
          `create_estimator_spec` becomes this Estimator's `model_fn`.
    """
    statistics_generator = math_utils.InputStatisticsFromMiniBatch(
        dtype=model.dtype, num_features=model.num_features)

    if state_manager is None:
      # ARModel instances get a filtering-only manager; every other model
      # gets a passthrough manager.
      state_manager = (
          state_management.FilteringOnlyStateManager()
          if isinstance(model, ar_model.ARModel)
          else state_management.PassthroughStateManager())

    if optimizer is None:
      optimizer = train.AdamOptimizer(0.02)

    self._model = model

    head = head_type(
        model=model,
        state_manager=state_manager,
        optimizer=optimizer,
        input_statistics_generator=statistics_generator)
    estimator_spec_fn = head.create_estimator_spec

    super(TimeSeriesRegressor, self).__init__(
        model_fn=estimator_spec_fn, model_dir=model_dir, config=config)
Ejemplo n.º 2
0
 def test_queue(self):
     """Run the input-statistics template across dtypes and feature counts.

     For each of float32 and float64, and for 1, 2 and 3 features, builds an
     InputStatisticsFromMiniBatch generator with that configuration. It then
     hands the generator to the shared template, using 1000 warmup iterations
     and a relative tolerance of 0.1.
     """
     float_types = (tf.dtypes.float32, tf.dtypes.float64)
     feature_counts = (1, 2, 3)
     for data_type in float_types:
         for width in feature_counts:
             generator = math_utils.InputStatisticsFromMiniBatch(
                 num_features=width, dtype=data_type)
             self._input_statistics_test_template(
                 generator,
                 num_features=width,
                 dtype=data_type,
                 warmup_iterations=1000,
                 rtol=0.1)