# Imports assumed from the enclosing module (tf.contrib.timeseries); the
# `_colate_features_to_feeds_and_fetches` helper is defined alongside this
# function in that module.
from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline


def filter_continuation(continue_from, signatures, session, features):
  """Perform filtering using an exported saved model.

  Filtering refers to updating model state based on new observations.
  Predictions based on the returned model state will be conditioned on these
  observations.

  Args:
    continue_from: A dictionary containing the results of either an Estimator's
      evaluate method or a previous filter_continuation. Used to determine the
      model state to start filtering from.
    signatures: The `MetaGraphDef` protocol buffer returned from
      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
      feed and fetch. Must be from the same model as `continue_from`.
    session: The session to use. The session's graph must be the one into which
      `tf.saved_model.loader.load` loaded the model.
    features: A dictionary mapping keys to Numpy arrays, with several possible
      shapes (requires keys `FilteringFeatures.TIMES` and
      `FilteringFeatures.VALUES`):
        Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
          vector of length [number of features].
        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
          has shape [series length] (univariate) or [series length x number of
          features] (multivariate).
        Batch of sequences; `TIMES` is a 2-D array of shape [batch size x
          series length], `VALUES` has shape [batch size x series length] or
          [batch size x series length x number of features].
      In any case, `VALUES` and any exogenous features must have their shapes
      prefixed by the shape of the value corresponding to the `TIMES` key.
  Returns:
    A dictionary containing model state updated to account for the observations
    in `features`.
  """
  filter_signature = signatures.signature_def[
      _feature_keys.SavedModelLabels.FILTER]
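  # Canonicalize the Numpy features, adding batch and time dimensions where
  # the input is a single example or an unbatched series.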
  features = _input_pipeline._canonicalize_numpy_data(  # pylint: disable=protected-access
      data=features,
      require_single_batch=False)
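  # Map feature values and the starting model state onto the signature's
  # input Tensors, and collect the output Tensors to fetch.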
  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
      continue_from=continue_from,
      signature=filter_signature,
      features=features,
      graph=session.graph)
  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
  # Make it easier to chain filter -> predict by keeping track of the current
  # time.
  output[_feature_keys.FilteringResults.TIMES] = features[
      _feature_keys.FilteringFeatures.TIMES]
  return output
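

# Usage sketch (illustrative, not part of the original module). This assumes
# the model was exported as a SavedModel under `export_dir` and that
# `evaluation` holds the dictionary returned by an Estimator's evaluate()
# call (or by a previous filter_continuation). `export_dir`, `evaluation`,
# `new_times`, and `new_values` are hypothetical names used only here.
def _example_filter_continuation(export_dir, evaluation, new_times, new_values):
  import tensorflow as tf  # Local import; the sketch targets the TF 1.x API.
  with tf.Graph().as_default():
    with tf.Session() as session:
      signatures = tf.saved_model.loader.load(
          session, [tf.saved_model.tag_constants.SERVING], export_dir)
      # Condition the model state on the new observations; for a univariate
      # series, `TIMES` has shape [series length] and `VALUES` matches it.
      state = filter_continuation(
          continue_from=evaluation,
          signatures=signatures,
          session=session,
          features={
              _feature_keys.FilteringFeatures.TIMES: new_times,
              _feature_keys.FilteringFeatures.VALUES: new_values,
          })
      return state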