def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
                        input_dimension, n_classes, batch_size):
    """Run a DNNClassifierV2 through train, evaluate, predict and export.

    Args:
      train_input_fn: Input function passed to `est.train`.
      eval_input_fn: Input function passed to `est.evaluate`.
      predict_input_fn: Input function passed to `est.predict`.
      input_dimension: Length of the numeric feature column `'x'`.
      n_classes: Number of classes the classifier is built with; also the
        expected second dimension of the predicted probabilities.
      batch_size: Expected number of prediction rows yielded by
        `predict_input_fn`.
    """
    # Local import: the export directory must be removed after the check and
    # this block cannot assume `shutil` is imported at file level.
    import shutil

    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,)),
    ]
    est = dnn.DNNClassifierV2(
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        n_classes=n_classes,
        model_dir=self._model_dir)

    # TRAIN
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    # Membership is checked on the mapping itself so a failure message shows
    # the available metric keys rather than an exhausted key iterator.
    self.assertIn('loss', scores)

    # PREDICT
    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PROBABILITIES]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)

    # EXPORT
    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    # mkdtemp() leaves deletion to the caller; keep a handle on the base
    # directory so the exported model does not accumulate across test runs.
    export_base = tempfile.mkdtemp()
    try:
        export_dir = est.export_saved_model(export_base,
                                            serving_input_receiver_fn)
        self.assertTrue(gfile.Exists(export_dir))
    finally:
        shutil.rmtree(export_base, ignore_errors=True)
def _dnn_classifier_fn(*args, **kwargs):
    """Build a `dnn.DNNClassifierV2`, forwarding all arguments unchanged."""
    classifier = dnn.DNNClassifierV2(*args, **kwargs)
    return classifier