예제 #1
0
  def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
               config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
    """Build an Estimator around a time series model.

    Args:
      model: The time series model being wrapped (inheriting from
          TimeSeriesModel). Its `dtype` and `num_features` configure the
          minibatch input statistics generator.
      state_manager: The state manager to use. When None, a
          FilteringOnlyStateManager is created if `model` is an
          `ar_model.ARModel`, otherwise a PassthroughStateManager.
      optimizer: The optimization algorithm used for training, inheriting
          from tf.train.Optimizer. When None, `train.AdamOptimizer(0.02)` is
          used.
      model_dir: See `Estimator`.
      config: See `Estimator`.
      head_type: The kind of head to build for the model (inheriting from
          `TimeSeriesRegressionHead`). Its `create_estimator_spec` method
          serves as the Estimator's model function.
    """
    statistics_generator = math_utils.InputStatisticsFromMiniBatch(
        dtype=model.dtype, num_features=model.num_features)
    if state_manager is None:
      # Autoregressive models get a different default state manager than
      # every other model type.
      wraps_ar_model = isinstance(model, ar_model.ARModel)
      state_manager = (
          state_management.FilteringOnlyStateManager() if wraps_ar_model
          else state_management.PassthroughStateManager())
    optimizer = train.AdamOptimizer(0.02) if optimizer is None else optimizer
    self._model = model
    head = head_type(
        model=model,
        state_manager=state_manager,
        optimizer=optimizer,
        input_statistics_generator=statistics_generator)
    super(TimeSeriesRegressor, self).__init__(
        model_fn=head.create_estimator_spec,
        model_dir=model_dir,
        config=config)
예제 #2
0
  def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
               config=None):
    """Build an Estimator around a time series model.

    Args:
      model: The time series model being wrapped (inheriting from
          TimeSeriesModel). Its `dtype` and `num_features` configure the
          minibatch input statistics generator.
      state_manager: The state manager to use. When None, a
          PassthroughStateManager is created.
      optimizer: The optimization algorithm used for training, inheriting
          from tf.train.Optimizer. When None, `train.AdamOptimizer(0.02)` is
          used.
      model_dir: See `Estimator`.
      config: See `Estimator`.
    """
    statistics_generator = math_utils.InputStatisticsFromMiniBatch(
        dtype=model.dtype, num_features=model.num_features)
    state_manager = (
        state_management.PassthroughStateManager() if state_manager is None
        else state_manager)
    optimizer = train.AdamOptimizer(0.02) if optimizer is None else optimizer
    self._model = model
    super(_TimeSeriesRegressor, self).__init__(
        model_fn=model_utils.make_model_fn(
            model, state_manager, optimizer,
            input_statistics_generator=statistics_generator),
        model_dir=model_dir,
        config=config)
예제 #3
0
 def test_queue(self):
   """Run the input statistics template on minibatch-based statistics.

   Covers float32 and float64 with one, two and three features. Each case
   calls the shared template with give_full_data=False, 1000 warmup
   iterations and a relative tolerance of 0.1.
   """
   # Same iteration order as nesting the feature loop inside the dtype loop.
   cases = [(dtype, num_features)
            for dtype in [dtypes.float32, dtypes.float64]
            for num_features in [1, 2, 3]]
   for dtype, num_features in cases:
     statistics_generator = math_utils.InputStatisticsFromMiniBatch(
         num_features=num_features, dtype=dtype)
     self._input_statistics_test_template(
         statistics_generator,
         num_features=num_features,
         dtype=dtype,
         give_full_data=False,
         warmup_iterations=1000,
         rtol=0.1)