def test_write_graph_and_checkpoint(self):
  """Round-trips a mock detector through `write_graph_and_checkpoint`.

  A FakeModel (with keypoints and masks enabled) is wired into the default
  graph through `exporter._build_detection_graph`; its graph and checkpoint
  are written to a temp directory. The exported meta graph is then imported
  into a fresh graph, restored, fed a batch of two identical tf.Examples, and
  every detection output is compared against the expected FakeModel values.
  """
  scratch_dir = self.get_temp_dir()
  source_ckpt = os.path.join(scratch_dir, 'model.ckpt')
  self._save_checkpoint_from_mock_model(source_ckpt,
                                        use_moving_averages=False)
  export_dir = os.path.join(scratch_dir, 'output')
  export_ckpt = os.path.join(export_dir, 'model.ckpt')
  tf.gfile.MakeDirs(export_dir)

  # Stage 1: build the inference graph for the mocked model and export it.
  with mock.patch.object(model_builder, 'build', autospec=True) as builder_mock:
    builder_mock.return_value = FakeModel(add_detection_keypoints=True,
                                          add_detection_masks=True)
    config = pipeline_pb2.TrainEvalPipelineConfig()
    config.eval_config.use_moving_averages = False
    model = model_builder.build(config.model, is_training=False)
    exporter._build_detection_graph(
        input_type='tf_example',
        detection_model=model,
        input_shape=None,
        output_collection_name='inference_op',
        graph_hook_fn=None)
    saver_def = tf.train.Saver().as_saver_def()
    exporter.write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=export_ckpt,
        input_saver_def=saver_def,
        trained_checkpoint_prefix=source_ckpt)

  # Stage 2: the same serialized example, stacked twice, forms the batch.
  one_example = self._create_tf_example(np.ones((4, 4, 3)).astype(np.uint8))
  batch_of_examples = np.hstack([one_example, one_example])

  fetch_names = ('detection_boxes', 'detection_scores', 'detection_classes',
                 'detection_keypoints', 'detection_masks', 'num_detections')

  # Stage 3: re-import the exported artifacts and check the outputs.
  with tf.Graph().as_default() as restored_graph:
    with self.test_session(graph=restored_graph) as sess:
      importer = tf.train.import_meta_graph(export_ckpt + '.meta')
      importer.restore(sess, export_ckpt)
      example_input = restored_graph.get_tensor_by_name('tf_example:0')
      fetches = [restored_graph.get_tensor_by_name(name + ':0')
                 for name in fetch_names]
      results = sess.run(fetches, feed_dict={example_input: batch_of_examples})
      (boxes_np, scores_np, classes_np, keypoints_np, masks_np,
       num_detections_np) = results

      expected_boxes = [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
                        [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]]
      self.assertAllClose(boxes_np, expected_boxes)
      self.assertAllClose(scores_np, [[0.7, 0.6], [0.9, 0.0]])
      self.assertAllClose(classes_np, [[1, 2], [2, 1]])
      self.assertAllClose(keypoints_np, np.arange(48).reshape([2, 2, 6, 2]))
      self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
      self.assertAllClose(num_detections_np, [2, 1])
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names=None,
                            input_shape=None,
                            output_collection_name='inference_op',
                            graph_hook_fn=None,
                            write_inference_graph=False,
                            temp_checkpoint_prefix=''):
  """Export helper.

  Builds the detection graph for `detection_model` in the current default
  graph and drops detection outputs that are not needed for box-only
  inference. It then writes the export artifacts under `output_directory`:

    * `model.ckpt` graph + checkpoint (via `write_graph_and_checkpoint`),
    * optionally `inference_graph.pbtxt` (text GraphDef, device fields
      cleared),
    * a SavedModel under `saved_model/00001` (via `write_saved_model`).

  Args:
    input_type: Forwarded to `build_detection_graph`.
    detection_model: Forwarded to `build_detection_graph`.
    use_moving_averages: If True, `replace_variable_values_with_moving_averages`
      writes a new checkpoint to `temp_checkpoint_prefix`. That checkpoint is
      exported instead of `trained_checkpoint_prefix`.
    trained_checkpoint_prefix: Prefix of the trained checkpoint to export.
    output_directory: Directory receiving the exported files; it is created
      with `tf.gfile.MakeDirs` if needed.
    additional_output_tensor_names: Optional iterable of extra output tensor
      names appended to the output node-name list.
    input_shape: Forwarded to `build_detection_graph`.
    output_collection_name: Forwarded to `build_detection_graph`.
    graph_hook_fn: Forwarded to `build_detection_graph`.
    write_inference_graph: If True, also write `inference_graph.pbtxt`.
    temp_checkpoint_prefix: Where the moving-average checkpoint is written
      when `use_moving_averages` is True. If empty, a temporary location is
      chosen.
  """
  tf.gfile.MakeDirs(output_directory)
  saved_model_path = os.path.join(output_directory, 'saved_model', '00001')
  model_path = os.path.join(output_directory, 'model.ckpt')
  outputs, placeholder_tensor = build_detection_graph(
      input_type=input_type,
      detection_model=detection_model,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook_fn)

  # OpenTPOD: popping unnecessary outputs for object detection inference.
  # see
  # https://github.com/tensorflow/models/blob/master/research/object_detection/core/standard_fields.py
  for unused_output in (
      fields.DetectionResultFields.detection_multiclass_scores,
      fields.DetectionResultFields.detection_features,
      fields.DetectionResultFields.detection_masks,
      fields.DetectionResultFields.detection_boundaries,
      fields.DetectionResultFields.detection_keypoints,
      fields.DetectionResultFields.raw_detection_boxes,
      fields.DetectionResultFields.raw_detection_scores,
      fields.DetectionResultFields.detection_anchor_indices):
    outputs.pop(unused_output, None)

  profile_inference_graph(tf.get_default_graph())

  saver_kwargs = {}
  if use_moving_averages:
    if not temp_checkpoint_prefix:
      # This check is to be compatible with both version of SaverDef.
      # A plain file at the prefix path is treated as a V1 checkpoint.
      if os.path.isfile(trained_checkpoint_prefix):
        saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
      else:
        temp_checkpoint_prefix = tempfile.mkdtemp()
    replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
        temp_checkpoint_prefix)
    checkpoint_to_use = temp_checkpoint_prefix
  else:
    checkpoint_to_use = trained_checkpoint_prefix

  saver = tf.train.Saver(**saver_kwargs)
  input_saver_def = saver.as_saver_def()
  write_graph_and_checkpoint(
      inference_graph_def=tf.get_default_graph().as_graph_def(),
      model_path=model_path,
      input_saver_def=input_saver_def,
      trained_checkpoint_prefix=checkpoint_to_use)

  if write_inference_graph:
    inference_graph_def = tf.get_default_graph().as_graph_def()
    inference_graph_path = os.path.join(output_directory,
                                        'inference_graph.pbtxt')
    # Strip device assignments from the written graph.
    for node in inference_graph_def.node:
      node.device = ''
    with tf.gfile.GFile(inference_graph_path, 'wb') as f:
      f.write(str(inference_graph_def))

  # Build an explicit list: on Python 3, dict.keys() returns a view, and a
  # view cannot be concatenated to a list with `+`.
  node_names = list(outputs.keys())
  if additional_output_tensor_names is not None:
    node_names.extend(additional_output_tensor_names)
  # NOTE(review): output_node_names is not referenced later in this function.
  output_node_names = ','.join(node_names)  # pylint: disable=unused-variable

  # Export from `checkpoint_to_use`, not `trained_checkpoint_prefix`. This
  # gives the SavedModel the same variable values as model.ckpt, including
  # the moving-average checkpoint created above. Assumes write_saved_model
  # restores variables from this prefix — confirm.
  write_saved_model(saved_model_path, checkpoint_to_use, placeholder_tensor,
                    outputs)