def testExportOutputsNoDict(self):
    """EstimatorSpec must raise TypeError when export_outputs is not a dict."""
    with ops.Graph().as_default(), self.test_session():
        predictions = {'loss': constant_op.constant(1.)}
        classes = constant_op.constant('hello')
        # assertRaisesRegexp is a deprecated alias that was removed in
        # Python 3.12; assertRaisesRegex is the supported spelling.
        with self.assertRaisesRegex(TypeError, 'export_outputs must be dict'):
            model_fn.EstimatorSpec(
                mode=model_fn.ModeKeys.PREDICT,
                predictions=predictions,
                # A bare ExportOutput rather than a {name: ExportOutput} dict.
                export_outputs=export.ClassificationOutput(classes=classes))
def _model_fn_for_export_tests(features, labels, mode):
    """Stub model_fn returning a fixed EstimatorSpec for export tests.

    ``features`` and ``labels`` are ignored. One variable named ``weight`` is
    created, and the spec exposes a single classification export output under
    the key ``'test'``.
    """
    del features, labels  # Unused by this stub model.
    variables.Variable(1., name='weight')
    score_tensor = constant_op.constant([3.])
    class_tensor = constant_op.constant(['wumpus'])
    prediction_tensor = constant_op.constant(10.)
    loss_tensor = constant_op.constant(1.)
    train_tensor = constant_op.constant(2.)
    outputs = {
        'test': export.ClassificationOutput(scores=score_tensor,
                                            classes=class_tensor),
    }
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=prediction_tensor,
        loss=loss_tensor,
        train_op=train_tensor,
        export_outputs=outputs)
def _model_fn_scaffold(features, labels, mode):
    """Stub model_fn whose Scaffold uses a Mock wrapping a real Saver.

    The mock is stored on the enclosing test as ``self.mock_saver`` (captured
    by closure), presumably so the test can inspect saver calls afterwards.
    """
    del features, labels  # Unused by this stub model.
    variables.Variable(1., name='weight')
    wrapped_saver = saver.Saver()
    # Calls on the mock are forwarded to the real saver; saver_def is set
    # explicitly so the mock exposes the real SaverDef object.
    self.mock_saver = test.mock.Mock(
        wraps=wrapped_saver, saver_def=wrapped_saver.saver_def)
    score_tensor = constant_op.constant([3.])
    prediction_tensor = constant_op.constant([[1.]])
    zero_loss = constant_op.constant(0.)
    dummy_train_op = constant_op.constant(0.)
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=prediction_tensor,
        loss=zero_loss,
        train_op=dummy_train_op,
        scaffold=training.Scaffold(saver=self.mock_saver),
        export_outputs={
            'test': export.ClassificationOutput(scores=score_tensor),
        })
def testAllArgumentsSet(self):
    """Building a TRAIN EstimatorSpec with every argument supplied must not raise."""
    with ops.Graph().as_default(), self.test_session():
        loss_tensor = constant_op.constant(1.)
        prediction_map = {'loss': loss_tensor}
        class_tensor = constant_op.constant('hello')
        # Built in the same order the original keyword arguments were evaluated.
        train_step = control_flow_ops.no_op()
        metric_ops = {'loss': (control_flow_ops.no_op(), loss_tensor)}
        outputs = {
            'head_name': export.ClassificationOutput(classes=class_tensor),
        }
        model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.TRAIN,
            predictions=prediction_map,
            loss=loss_tensor,
            train_op=train_step,
            eval_metric_ops=metric_ops,
            export_outputs=outputs,
            training_chief_hooks=[_FakeHook()],
            training_hooks=[_FakeHook()],
            scaffold=monitored_session.Scaffold())
def _model_fn_scaffold(features, labels, mode):
    """Stub model_fn whose Scaffold carries a custom local_init_op.

    The custom op depends on the default local-variable and table
    initializers, then assigns 12345 to the local variable ``my_int``.
    """
    del features, labels  # Unused by this stub model.
    counter = variables.Variable(
        1, name='my_int', collections=[ops.GraphKeys.LOCAL_VARIABLES])
    score_tensor = constant_op.constant([3.])
    prerequisites = [
        variables.local_variables_initializer(),
        data_flow_ops.tables_initializer(),
    ]
    with ops.control_dependencies(prerequisites):
        set_counter = state_ops.assign(counter, 12345)
    # local_init_op must be an Operation, not a Tensor, so the assign result
    # is wrapped in a group op.
    init_op = control_flow_ops.group(set_counter)
    prediction_tensor = constant_op.constant([[1.]])
    zero_loss = constant_op.constant(0.)
    dummy_train_op = constant_op.constant(0.)
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=prediction_tensor,
        loss=zero_loss,
        train_op=dummy_train_op,
        scaffold=training.Scaffold(local_init_op=init_op),
        export_outputs={
            'test': export.ClassificationOutput(scores=score_tensor),
        })
def testExportOutputsMultiheadWithDefault(self):
    """A multi-head export_outputs dict that has a default-serving key is accepted.

    The resulting spec's export_outputs must compare equal to the dict given.
    """
    with ops.Graph().as_default(), self.test_session():
        prediction_map = {'loss': constant_op.constant(1.)}
        regression_values = constant_op.constant([1.])
        head2_classes = constant_op.constant(['2'])
        head3_values = constant_op.constant(['3'])
        heads = {}
        heads[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
            export.RegressionOutput(value=regression_values))
        heads['head-2'] = export.ClassificationOutput(classes=head2_classes)
        heads['head-3'] = export.PredictOutput(
            outputs={'some_output_3': head3_values})
        spec = model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=prediction_map,
            export_outputs=heads)
        self.assertEqual(heads, spec.export_outputs)
def testExportOutputsMultiheadMissingDefault(self):
    """Several export_outputs with no default-serving key must raise ValueError."""
    with ops.Graph().as_default(), self.test_session():
        predictions = {'loss': constant_op.constant(1.)}
        output_1 = constant_op.constant([1.])
        output_2 = constant_op.constant(['2'])
        output_3 = constant_op.constant(['3'])
        # None of these keys is DEFAULT_SERVING_SIGNATURE_DEF_KEY.
        export_outputs = {
            'head-1': export.RegressionOutput(value=output_1),
            'head-2': export.ClassificationOutput(classes=output_2),
            'head-3': export.PredictOutput(outputs={'some_output_3': output_3}),
        }
        # assertRaisesRegexp is a deprecated alias removed in Python 3.12;
        # assertRaisesRegex is the supported spelling. The expected message is
        # treated as a regex, so its '.' characters match any character.
        with self.assertRaisesRegex(
                ValueError,
                'Multiple export_outputs were provided, but none of them is '
                'specified as the default. Do this by naming one of them with '
                'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.'):
            model_fn.EstimatorSpec(mode=model_fn.ModeKeys.PREDICT,
                                   predictions=predictions,
                                   export_outputs=export_outputs)