Example 1
0
  def testExportClassInEstimator(self):
    """Checks that estimator exports stay separate from core TF exports.

    TestClassA is exported through `tf_export` and TestClassB through
    `estimator_export`. Registering B must not write the core-TF
    `_tf_api_names` attribute onto B, and must not disturb the names already
    recorded for A.
    """
    export_decorator_a = tf_export.tf_export('TestClassA1')
    export_decorator_a(TestClassA)
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)

    export_decorator_b = tf_export.estimator_export(
        'estimator.TestClassB1')
    export_decorator_b(TestClassB)
    # The estimator export must not define `_tf_api_names` directly on B.
    self.assertNotIn('_tf_api_names', TestClassB.__dict__)
    # Re-check A after exporting B to make sure B's export left it unchanged.
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
    self.assertEqual(['estimator.TestClassB1'],
                     tf_export.get_v1_names(TestClassB))
Example 2
0
  def testRaisesExceptionIfInvalidSymbolName(self):
    """Checks that each exporter rejects names outside its own namespace."""
    # Core TensorFlow code may not export anything under tf.estimator.
    with self.assertRaises(tf_export.InvalidSymbolNameError):
      tf_export.tf_export('estimator.invalid')

    # Estimator exports must be rooted at the `estimator` package. None of
    # these names qualify: no prefix, a wrongly-capitalised prefix, and
    # `estimator` appearing only as a trailing component.
    for bad_name in ('invalid', 'Estimator.invalid', 'invalid.estimator'):
      with self.assertRaises(tf_export.InvalidSymbolNameError):
        tf_export.estimator_export(bad_name)
Example 3
0
 def testRaisesExceptionIfInvalidV1SymbolName(self):
     """Checks that an invalid `v1` alias is rejected by each exporter."""
     invalid_cases = (
         # tf_export may not publish a v1 alias under the estimator package.
         (tf_export.tf_export, 'valid', ['estimator.invalid']),
         # estimator_export may not publish a v1 alias outside that package.
         (tf_export.estimator_export, 'estimator.valid', ['invalid']),
     )
     for exporter, name, v1_names in invalid_cases:
         with self.assertRaises(tf_export.InvalidSymbolNameError):
             exporter(name, v1=v1_names)
Example 4
0
 def testRaisesExceptionIfInvalidV1SymbolName(self):
   """Invalid `v1` names must raise InvalidSymbolNameError at export time."""
   error = tf_export.InvalidSymbolNameError
   # tf_export must refuse a v1 alias placed under the estimator package.
   self.assertRaises(error, tf_export.tf_export, 'valid',
                     v1=['estimator.invalid'])
   # estimator_export must refuse a v1 alias outside the estimator package.
   self.assertRaises(error, tf_export.estimator_export, 'estimator.valid',
                     v1=['invalid'])
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverListener
from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterHook
from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook
from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError
from tensorflow.python.training.basic_session_run_hooks import NanTensorHook
from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer
from tensorflow.python.training.basic_session_run_hooks import StepCounterHook
from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook
from tensorflow.python.training.basic_session_run_hooks import SummarySaverHook
from tensorflow.python.util.tf_export import estimator_export

# Publish each hook utility imported from basic_session_run_hooks under the
# tf.estimator namespace. Registration happens in the order listed.
for _export_name, _hook_symbol in (
    ("estimator.SecondOrStepTimer", SecondOrStepTimer),
    ("estimator.LoggingTensorHook", LoggingTensorHook),
    ("estimator.StopAtStepHook", StopAtStepHook),
    ("estimator.CheckpointSaverListener", CheckpointSaverListener),
    ("estimator.CheckpointSaverHook", CheckpointSaverHook),
    ("estimator.StepCounterHook", StepCounterHook),
    ("estimator.NanLossDuringTrainingError", NanLossDuringTrainingError),
    ("estimator.NanTensorHook", NanTensorHook),
    ("estimator.SummarySaverHook", SummarySaverHook),
    ("estimator.GlobalStepWaiterHook", GlobalStepWaiterHook),
    ("estimator.FinalOpsHook", FinalOpsHook),
    ("estimator.FeedFnHook", FeedFnHook),
    ("estimator.ProfilerHook", ProfilerHook),
):
  estimator_export(_export_name)(_hook_symbol)
# Drop the loop temporaries so the module namespace is unchanged.
del _export_name, _hook_symbol
Example 6
0
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different types of export output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from tensorflow.python.saved_model.model_utils.export_output import _SupervisedOutput
from tensorflow.python.saved_model.model_utils.export_output import ClassificationOutput
from tensorflow.python.saved_model.model_utils.export_output import EvalOutput
from tensorflow.python.saved_model.model_utils.export_output import ExportOutput
from tensorflow.python.saved_model.model_utils.export_output import PredictOutput
from tensorflow.python.saved_model.model_utils.export_output import RegressionOutput
from tensorflow.python.saved_model.model_utils.export_output import TrainOutput
# pylint: enable=unused-import
from tensorflow.python.util.tf_export import estimator_export

# Only these four output classes are published under tf.estimator.export.
# The other classes imported above are re-exported as plain module attributes.
for _export_name, _output_cls in (
    ('estimator.export.ExportOutput', ExportOutput),
    ('estimator.export.ClassificationOutput', ClassificationOutput),
    ('estimator.export.RegressionOutput', RegressionOutput),
    ('estimator.export.PredictOutput', PredictOutput),
):
  estimator_export(_export_name)(_output_cls)
# Remove the loop temporaries so they do not become module attributes.
del _export_name, _output_cls
Example 7
0
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exporting ModeKeys to tf.estimator namespace."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.saved_model.model_utils.mode_keys import EstimatorModeKeys as ModeKeys
from tensorflow.python.util.tf_export import estimator_export

# ModeKeys is an alias of EstimatorModeKeys; publish it as tf.estimator.ModeKeys.
_mode_keys_exporter = estimator_export('estimator.ModeKeys')
_mode_keys_exporter(ModeKeys)
del _mode_keys_exporter  # keep the module namespace free of the temporary
Example 8
0
  call hooks.after_create_session()
  while not stop is requested:
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()

Note that if sess.run() raises OutOfRangeError or StopIteration then
hooks.after_run() will not be called but hooks.end() will still be called.
If sess.run() raises any other exception then neither hooks.after_run() nor
hooks.end() will be called.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.session_run_hook import SessionRunContext
from tensorflow.python.training.session_run_hook import SessionRunHook
from tensorflow.python.training.session_run_hook import SessionRunValues
from tensorflow.python.util.tf_export import estimator_export

# Expose the SessionRun* building blocks from session_run_hook under
# tf.estimator, in the same order as before.
for _export_name, _session_symbol in (
    ("estimator.SessionRunHook", SessionRunHook),
    ("estimator.SessionRunArgs", SessionRunArgs),
    ("estimator.SessionRunContext", SessionRunContext),
    ("estimator.SessionRunValues", SessionRunValues),
):
  estimator_export(_export_name)(_session_symbol)
# The loop variables are scaffolding only; do not leave them on the module.
del _export_name, _session_symbol