def test_parse_input_fn_result_invalid(self):
  """A Dataset whose elements are 3-tuples must be rejected with ValueError."""

  def _input_fn():
    features = np.expand_dims(np.arange(100), 0)
    labels = np.expand_dims(np.arange(100, 200), 0)
    # Three components instead of (features, labels): expected to be invalid.
    return dataset_ops.Dataset.from_tensor_slices((features, labels, labels))

  # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12);
  # `assertRaisesRegex` is the supported spelling.
  with self.assertRaisesRegex(ValueError, 'input_fn should return'):
    util.parse_input_fn_result(_input_fn())
def test_parse_input_fn_result_invalid(self):
  """parse_input_fn_result raises ValueError for a 3-component Dataset."""

  def _input_fn():
    features = np.expand_dims(np.arange(100), 0)
    labels = np.expand_dims(np.arange(100, 200), 0)
    # Yields (features, labels, labels); only (features, labels) is accepted.
    return dataset_ops.Dataset.from_tensor_slices(
        (features, labels, labels))

  # Use the non-deprecated name; `assertRaisesRegexp` was removed in 3.12.
  with self.assertRaisesRegex(ValueError, 'input_fn should return'):
    util.parse_input_fn_result(_input_fn())
def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args):
  """Builds a SupervisedInputReceiver-producing function from an input_fn.

  The input_fn is invoked inside a separate, local graph only to discover the
  features and labels it produces. Placeholders mirroring those features and
  labels are then created in the default graph.

  Args:
    input_fn: An Estimator input_fn, i.e. a function returning either:
      * A `tf.data.Dataset` object: each element must be a tuple
        (features, labels) obeying the same constraints as below.
      * A tuple (features, labels): `features` is a `Tensor` or a dict mapping
        string feature names to `Tensor`, and `labels` is a `Tensor` or a dict
        mapping string label names to `Tensor`. Both are consumed by
        `model_fn` and should match what `model_fn` expects as inputs.
    **input_fn_args: keyword arguments forwarded to input_fn unchanged. They
      are not checked or validated here; any error input_fn raises propagates
      to the caller.

  Returns:
    A zero-argument function that returns a SupervisedInputReceiver when
    called. It can be placed in the input_receiver_map when exporting
    SavedModels from an Estimator with multiple modes.
  """
  # A throwaway graph keeps the input_fn's ops out of the default namespace.
  scratch_graph = ops.Graph()
  with scratch_graph.as_default():
    input_fn_output = input_fn(**input_fn_args)
    feature_tensors, label_tensors, _ = util.parse_input_fn_result(
        input_fn_output)

  # Back outside the scratch graph: placeholders belong to the default graph.
  return build_raw_supervised_input_receiver_fn(feature_tensors, label_tensors)
def test_parse_input_fn_result_features_only(self):
  """A lone Tensor parses as features, with no labels and no hooks."""

  def _features_only_input_fn():
    return constant_op.constant(np.arange(100))

  parsed_features, parsed_labels, parsed_hooks = util.parse_input_fn_result(
      _features_only_input_fn())
  with self.cached_session() as sess:
    fetched = sess.run([parsed_features])

  expected_features = np.arange(100)
  self.assertAllEqual(fetched[0], expected_features)
  self.assertEqual(parsed_labels, None)
  self.assertEqual(parsed_hooks, [])
def test_parse_input_fn_result_features_only(self):
  """A bare features Tensor parses to (features, None, [])."""

  def _input_fn():
    return constant_op.constant(np.arange(100))

  features, labels, hooks = util.parse_input_fn_result(_input_fn())
  # `test_session()` is deprecated in TensorFlow; `cached_session()` replaces it.
  with self.cached_session() as sess:
    vals = sess.run([features])

  self.assertAllEqual(vals[0], np.arange(100))
  self.assertIsNone(labels)
  self.assertEqual(hooks, [])
def test_parse_input_fn_result_features_only_dataset(self):
  """A features-only Dataset yields features, no labels, and an init hook."""

  def _dataset_input_fn():
    # Leading axis of size 1, so slicing produces a single Dataset element.
    single_batch = np.expand_dims(np.arange(100), 0)
    return dataset_ops.Dataset.from_tensor_slices(single_batch)

  dataset = _dataset_input_fn()
  parsed_features, parsed_labels, parsed_hooks = util.parse_input_fn_result(
      dataset)
  with training.MonitoredSession(hooks=parsed_hooks) as sess:
    fetched = sess.run([parsed_features])

  self.assertAllEqual(fetched[0], np.arange(100))
  self.assertEqual(parsed_labels, None)
  first_hook = parsed_hooks[0]
  self.assertIsInstance(first_hook, util._DatasetInitializerHook)
def test_parse_input_fn_result_tuple(self):
  """A (features, labels) tuple of Tensors parses with an empty hook list."""

  def _tuple_input_fn():
    feature_tensor = constant_op.constant(np.arange(100))
    label_tensor = constant_op.constant(np.arange(100, 200))
    return feature_tensor, label_tensor

  parsed_features, parsed_labels, parsed_hooks = util.parse_input_fn_result(
      _tuple_input_fn())
  with self.cached_session() as sess:
    fetched = sess.run([parsed_features, parsed_labels])

  feature_values, label_values = fetched[0], fetched[1]
  self.assertAllEqual(feature_values, np.arange(100))
  self.assertAllEqual(label_values, np.arange(100, 200))
  self.assertEqual(parsed_hooks, [])
def test_parse_input_fn_result_tuple(self):
  """A (features, labels) Tensor tuple parses to (features, labels, [])."""

  def _input_fn():
    features = constant_op.constant(np.arange(100))
    labels = constant_op.constant(np.arange(100, 200))
    return features, labels

  features, labels, hooks = util.parse_input_fn_result(_input_fn())
  # `test_session()` is deprecated in TensorFlow; `cached_session()` replaces it.
  with self.cached_session() as sess:
    vals = sess.run([features, labels])

  self.assertAllEqual(vals[0], np.arange(100))
  self.assertAllEqual(vals[1], np.arange(100, 200))
  self.assertEqual(hooks, [])