コード例 #1
0
 def test_metrics_consistent(self):
     """Check in-sample identity metrics against a standard mean loss metric.

     Builds an EVAL-mode estimator spec from a `TimeSeriesRegressionHead`
     wrapping `_TickerModel`, runs every metric update op once, and then
     asserts that the mean loss, the "ticker" metric and the first element
     of the state-tuple metric all evaluate to zero.
     """
     # Tests that the identity metrics used to report in-sample predictions match
     # the behavior of standard metrics.
     g = ops.Graph()
     with g.as_default():
         features = {
             feature_keys.TrainEvalFeatures.TIMES:
                 array_ops.zeros((1, 1)),
             feature_keys.TrainEvalFeatures.VALUES:
                 array_ops.zeros((1, 1, 1)),
             # A local counter variable whose count_up_to(10) op is cast to
             # float32 and reshaped to (1, 1, 1).
             "ticker":
                 array_ops.reshape(
                     math_ops.cast(
                         variables.VariableV1(
                             name="ticker",
                             initial_value=0,
                             dtype=dtypes.int64,
                             collections=[ops.GraphKeys.LOCAL_VARIABLES],
                         ).count_up_to(10),
                         dtype=dtypes.float32),
                     (1, 1, 1)),
         }
         model_fn = ts_head_lib.TimeSeriesRegressionHead(
             model=_TickerModel(),
             state_manager=state_management.PassthroughStateManager(),
             optimizer=train.GradientDescentOptimizer(
                 0.001)).create_estimator_spec
         outputs = model_fn(features=features,
                            labels=None,
                            mode=estimator_lib.ModeKeys.EVAL)
         # Each eval_metric_ops entry is indexed as (value, update_op) here;
         # collect the update ops so they can all be run in one step.
         metric_update_ops = [
             metric[1] for metric in outputs.eval_metric_ops.values()
         ]
         loss_mean, loss_update = metrics.mean(outputs.loss)
         metric_update_ops.append(loss_update)
         with self.cached_session() as sess:
             coordinator = coordinator_lib.Coordinator()
             try:
                 queue_runner_impl.start_queue_runners(sess, coord=coordinator)
                 variables.local_variables_initializer().run()
                 sess.run(metric_update_ops)
                 loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
                     (loss_mean, outputs.eval_metric_ops["ticker"][0],
                      outputs.eval_metric_ops[
                          feature_keys.FilteringResults.STATE_TUPLE][0][0]))
                 # The custom model_utils metrics for in-sample predictions
                 # should be in sync with the Estimator's mean metric for
                 # model loss.
                 self.assertAllClose(0., loss_evaled)
                 self.assertAllClose((((0., ), ), ), metric_evaled)
                 self.assertAllClose((((0., ), ), ), nested_metric_evaled)
             finally:
                 # Always stop and join the queue-runner threads, even when a
                 # run or assertion above fails, so they don't outlive the test.
                 coordinator.request_stop()
                 coordinator.join()
コード例 #2
0
def _stub_model_fn():
    """Return `create_estimator_spec` of a regression head over `_StubModel`.

    The head is built with a `PassthroughStateManager` and
    `train.AdamOptimizer(0.001)`; only its bound `create_estimator_spec`
    method is handed back.
    """
    stub_head = ts_head_lib.TimeSeriesRegressionHead(
        model=_StubModel(),
        state_manager=state_management.PassthroughStateManager(),
        optimizer=train.AdamOptimizer(0.001),
    )
    return stub_head.create_estimator_spec