Exemple #1
0
        def transform_and_pad_input_data_fn(tensor_dict):
            """Transform one training example and pad it to static shapes.

            Builds the configured augmentation steps, the model (used only for
            its `preprocess` method) and the image resizer, runs
            `transform_input_data` on `tensor_dict`, and pads the result with
            `pad_input_data_to_static_shapes`.

            Args:
              tensor_dict: Input dictionary handed to `transform_input_data`.

            Returns:
              A `(features, labels)` tuple produced from the padded dictionary
              by `_get_features_dict` and `_get_labels_dict`.
            """
            augmentation_steps = list(
                map(preprocessor_builder.build,
                    train_config.data_augmentation_options))
            augment_fn = functools.partial(
                augment_input_data,
                data_augmentation_options=augmentation_steps)

            # The model is built here only to obtain its `preprocess` method.
            detection_model = model_builder.build(
                model_config, is_training=True)
            resizer_config = config_util.get_image_resizer_config(
                model_config)
            resize_fn = image_resizer_builder.build(resizer_config)

            transform_kwargs = dict(
                model_preprocess_fn=detection_model.preprocess,
                image_resizer_fn=resize_fn,
                num_classes=config_util.get_number_of_classes(model_config),
                data_augmentation_fn=augment_fn,
                merge_multiple_boxes=train_config.merge_multiple_label_boxes,
                retain_original_image=train_config.retain_original_images,
                use_bfloat16=train_config.use_bfloat16)
            transform_fn = functools.partial(
                transform_input_data, **transform_kwargs)

            transformed = transform_fn(tensor_dict)
            padded = pad_input_data_to_static_shapes(
                tensor_dict=transformed,
                max_num_boxes=train_input_config.max_number_of_boxes,
                num_classes=config_util.get_number_of_classes(model_config),
                spatial_image_shape=config_util.get_spatial_image_size(
                    resizer_config))
            features = _get_features_dict(padded)
            labels = _get_labels_dict(padded)
            return (features, labels)
  def test_write_graph_and_checkpoint(self):
    """Round-trips a graph and checkpoint through the exporter.

    A checkpoint is saved from the mock model, the detection graph is built
    with `model_builder.build` patched to return a `FakeModel`, and
    `exporter.write_graph_and_checkpoint` writes it under the output
    directory. The written meta graph is then imported into a fresh graph,
    fed a pair of identical tf.Examples, and the detection outputs are
    compared with the expected values.
    """
    tmp_dir = self.get_temp_dir()
    source_checkpoint = os.path.join(tmp_dir, 'model.ckpt')
    self._save_checkpoint_from_mock_model(
        source_checkpoint, use_moving_averages=False)
    export_dir = os.path.join(tmp_dir, 'output')
    exported_checkpoint = os.path.join(export_dir, 'model.ckpt')
    tf.gfile.MakeDirs(export_dir)

    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(
          add_detection_keypoints=True, add_detection_masks=True)
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      pipeline_config.eval_config.use_moving_averages = False
      fake_model = model_builder.build(
          pipeline_config.model, is_training=False)
      exporter._build_detection_graph(
          input_type='tf_example',
          detection_model=fake_model,
          input_shape=None,
          output_collection_name='inference_op',
          graph_hook_fn=None)
      saver_def = tf.train.Saver().as_saver_def()
      exporter.write_graph_and_checkpoint(
          inference_graph_def=tf.get_default_graph().as_graph_def(),
          model_path=exported_checkpoint,
          input_saver_def=saver_def,
          trained_checkpoint_prefix=source_checkpoint)

    # Two copies of the same serialized example form the input batch.
    single_example = self._create_tf_example(
        np.ones((4, 4, 3)).astype(np.uint8))
    example_batch = np.hstack([single_example] * 2)

    fetch_names = (
        'detection_boxes:0',
        'detection_scores:0',
        'detection_classes:0',
        'detection_keypoints:0',
        'detection_masks:0',
        'num_detections:0',
    )
    with tf.Graph().as_default() as reloaded_graph:
      with self.test_session(graph=reloaded_graph) as sess:
        restorer = tf.train.import_meta_graph(exported_checkpoint + '.meta')
        restorer.restore(sess, exported_checkpoint)

        example_input = reloaded_graph.get_tensor_by_name('tf_example:0')
        fetches = [
            reloaded_graph.get_tensor_by_name(name) for name in fetch_names
        ]
        (boxes_np, scores_np, classes_np, keypoints_np, masks_np,
         num_detections_np) = sess.run(
             fetches, feed_dict={example_input: example_batch})

        expected_boxes = [[[0.0, 0.0, 0.5, 0.5],
                           [0.5, 0.5, 0.8, 0.8]],
                          [[0.5, 0.5, 1.0, 1.0],
                           [0.0, 0.0, 0.0, 0.0]]]
        expected_scores = [[0.7, 0.6],
                           [0.9, 0.0]]
        expected_classes = [[1, 2],
                            [2, 1]]
        self.assertAllClose(boxes_np, expected_boxes)
        self.assertAllClose(scores_np, expected_scores)
        self.assertAllClose(classes_np, expected_classes)
        self.assertAllClose(keypoints_np, np.arange(48).reshape([2, 2, 6, 2]))
        self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
        self.assertAllClose(num_detections_np, [2, 1])
Exemple #3
0
 def test_create_faster_rcnn_model_from_config_with_example_miner(self):
     """Builds a faster_rcnn model whose config has a hard_example_miner.

     Parses the text proto below into a `DetectionModel` message, builds the
     model in training mode and checks that its `_hard_example_miner`
     attribute is not None.
     """
     config_text = """
   faster_rcnn {
     num_classes: 3
     feature_extractor {
       type: 'faster_rcnn_inception_resnet_v2'
     }
     image_resizer {
       keep_aspect_ratio_resizer {
         min_dimension: 600
         max_dimension: 1024
       }
     }
     first_stage_anchor_generator {
       grid_anchor_generator {
         scales: [0.25, 0.5, 1.0, 2.0]
         aspect_ratios: [0.5, 1.0, 2.0]
         height_stride: 16
         width_stride: 16
       }
     }
     first_stage_box_predictor_conv_hyperparams {
       regularizer {
         l2_regularizer {
         }
       }
       initializer {
         truncated_normal_initializer {
         }
       }
     }
     second_stage_box_predictor {
       mask_rcnn_box_predictor {
         fc_hyperparams {
           op: FC
           regularizer {
             l2_regularizer {
             }
           }
           initializer {
             truncated_normal_initializer {
             }
           }
         }
       }
     }
     hard_example_miner {
       num_hard_examples: 10
       iou_threshold: 0.99
     }
   }"""
     detection_config = model_pb2.DetectionModel()
     text_format.Merge(config_text, detection_config)
     built_model = model_builder.build(detection_config, is_training=True)
     self.assertIsNotNone(built_model._hard_example_miner)
Exemple #4
0
    def create_model(self, model_config):
        """Build a DetectionModel in training mode from `model_config`.

        Args:
          model_config: A model.proto object containing the config for the
            desired DetectionModel.

        Returns:
          The object returned by `model_builder.build` when called with
          `is_training=True`.
        """
        detection_model = model_builder.build(model_config, is_training=True)
        return detection_model
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False):
    """Exports the inference graph for the model in the pipeline config.

    Builds the detection model in inference mode, delegates the export to
    `_export_inference_graph`, and finally saves the pipeline config into
    `output_directory`.

    Note: `pipeline_config.eval_config.use_moving_averages` is set to False
    on the passed-in proto before it is saved, so the caller's object is
    modified.

    Args:
      input_type: Type of input for the graph. Can be one of ['image_tensor',
        'encoded_image_string_tensor', 'tf_example'].
      pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto.
      trained_checkpoint_prefix: Path to the trained checkpoint file.
      output_directory: Path to write outputs.
      input_shape: Sets a fixed shape for an `image_tensor` input. If not
        specified, will default to [None, None, None, 3].
      output_collection_name: Name of collection to add output tensors to.
        If None, does not add output tensors to a collection.
      additional_output_tensor_names: list of additional output
        tensors to include in the frozen graph.
      write_inference_graph: If true, writes inference graph to disk.
    """
    detection_model = model_builder.build(
        pipeline_config.model, is_training=False)

    # A graph rewriter hook is only built when the config declares one.
    if pipeline_config.HasField('graph_rewriter'):
        graph_hook = graph_rewriter_builder.build(
            pipeline_config.graph_rewriter, is_training=False)
    else:
        graph_hook = None

    _export_inference_graph(
        input_type,
        detection_model,
        pipeline_config.eval_config.use_moving_averages,
        trained_checkpoint_prefix,
        output_directory,
        additional_output_tensor_names,
        input_shape,
        output_collection_name,
        graph_hook_fn=graph_hook,
        write_inference_graph=write_inference_graph)

    # The saved config always records use_moving_averages=False.
    pipeline_config.eval_config.use_moving_averages = False
    config_util.save_pipeline_config(pipeline_config, output_directory)
Exemple #6
0
    def _predict_input_fn(params=None):
        """Builds a `ServingInputReceiver` fed by a serialized tf.Example.

        A scalar string placeholder named 'tf_example' is decoded with
        `TfExampleDecoder`, passed through `transform_input_data` (without
        data augmentation), and the resulting image and true image shape each
        get a leading axis added before being exposed as features.

        Args:
          params: Parameter dictionary passed from the estimator. Unused.

        Returns:
          `ServingInputReceiver`.
        """
        del params  # Not used by this input function.
        serialized_example = tf.placeholder(
            dtype=tf.string, shape=[], name='tf_example')

        num_classes = config_util.get_number_of_classes(model_config)
        detection_model = model_builder.build(model_config, is_training=False)
        resizer_config = config_util.get_image_resizer_config(model_config)
        resize_fn = image_resizer_builder.build(resizer_config)

        preprocess = functools.partial(
            transform_input_data,
            model_preprocess_fn=detection_model.preprocess,
            image_resizer_fn=resize_fn,
            num_classes=num_classes,
            data_augmentation_fn=None)

        example_decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=False,
            num_additional_channels=(
                predict_input_config.num_additional_channels))
        transformed = preprocess(example_decoder.decode(serialized_example))

        # Add a leading axis of size one to both served tensors.
        image_batch = tf.expand_dims(
            tf.to_float(transformed[fields.InputDataFields.image]), axis=0)
        shape_batch = tf.expand_dims(
            transformed[fields.InputDataFields.true_image_shape], axis=0)

        served_features = {
            fields.InputDataFields.image: image_batch,
            fields.InputDataFields.true_image_shape: shape_batch,
        }
        receivers = {SERVING_FED_EXAMPLE_KEY: serialized_example}
        return tf.estimator.export.ServingInputReceiver(
            features=served_features, receiver_tensors=receivers)
Exemple #7
0
        def transform_and_pad_input_data_fn(tensor_dict):
            """Transform one evaluation example and pad it to static shapes.

            Builds the model in inference mode (used only for its
            `preprocess` method) and the image resizer, runs
            `transform_input_data` on `tensor_dict` without data
            augmentation, and pads the result with
            `pad_input_data_to_static_shapes`.

            Args:
              tensor_dict: Input dictionary handed to `transform_input_data`.

            Returns:
              A `(features, labels)` tuple produced from the padded dictionary
              by `_get_features_dict` and `_get_labels_dict`.
            """
            num_classes = config_util.get_number_of_classes(model_config)
            detection_model = model_builder.build(
                model_config, is_training=False)
            resizer_config = config_util.get_image_resizer_config(
                model_config)
            resize_fn = image_resizer_builder.build(resizer_config)

            transform_kwargs = dict(
                model_preprocess_fn=detection_model.preprocess,
                image_resizer_fn=resize_fn,
                num_classes=num_classes,
                data_augmentation_fn=None,
                retain_original_image=eval_config.retain_original_images)
            transform_fn = functools.partial(
                transform_input_data, **transform_kwargs)

            transformed = transform_fn(tensor_dict)
            padded = pad_input_data_to_static_shapes(
                tensor_dict=transformed,
                max_num_boxes=eval_input_config.max_number_of_boxes,
                num_classes=config_util.get_number_of_classes(model_config),
                spatial_image_shape=config_util.get_spatial_image_size(
                    resizer_config))
            features = _get_features_dict(padded)
            labels = _get_labels_dict(padded)
            return (features, labels)
Exemple #8
0
    def test_create_ssd_resnet_v1_ppn_model_from_config(self):
        """Builds an ssd model for every entry of SSD_RESNET_V1_PPN_FEAT_MAPS.

        The text proto below is parsed once; for each (type, class) pair in
        the map the feature extractor type is overwritten, the model is built
        in training mode, and both the meta architecture and the feature
        extractor class are checked.
        """
        config_text = """
      ssd {
        feature_extractor {
          type: 'ssd_resnet_v1_50_ppn'
          conv_hyperparams {
            regularizer {
                l2_regularizer {
                }
              }
              initializer {
                truncated_normal_initializer {
                }
              }
          }
        }
        box_coder {
          mean_stddev_box_coder {
          }
        }
        matcher {
          bipartite_matcher {
          }
        }
        similarity_calculator {
          iou_similarity {
          }
        }
        anchor_generator {
          ssd_anchor_generator {
            aspect_ratios: 1.0
          }
        }
        image_resizer {
          fixed_shape_resizer {
            height: 320
            width: 320
          }
        }
        box_predictor {
          weight_shared_convolutional_box_predictor {
            depth: 1024
            class_prediction_bias_init: -4.6
            conv_hyperparams {
              activation: RELU_6,
              regularizer {
                l2_regularizer {
                  weight: 0.0004
                }
              }
              initializer {
                variance_scaling_initializer {
                }
              }
            }
            num_layers_before_predictor: 2
            kernel_size: 1
          }
        }
        loss {
          classification_loss {
            weighted_softmax {
            }
          }
          localization_loss {
            weighted_l2 {
            }
          }
          classification_weight: 1.0
          localization_weight: 1.0
        }
      }"""
        detection_config = model_pb2.DetectionModel()
        text_format.Merge(config_text, detection_config)

        # The same proto is reused; only the extractor type changes per pass.
        feature_maps = SSD_RESNET_V1_PPN_FEAT_MAPS
        for type_name, expected_cls in feature_maps.items():
            detection_config.ssd.feature_extractor.type = type_name
            built_model = model_builder.build(
                detection_config, is_training=True)
            self.assertIsInstance(built_model, ssd_meta_arch.SSDMetaArch)
            self.assertIsInstance(
                built_model._feature_extractor, expected_cls)
Exemple #9
0
    def test_create_ssd_resnet_v1_fpn_model_from_config(self):
        """Builds an ssd model for every entry of SSD_RESNET_V1_FPN_FEAT_MAPS.

        The text proto below is parsed once; for each (type, class) pair in
        the map the feature extractor type is overwritten, the model is built
        in training mode, and both the meta architecture and the feature
        extractor class are checked.
        """
        config_text = """
      ssd {
        feature_extractor {
          type: 'ssd_resnet50_v1_fpn'
          fpn {
            min_level: 3
            max_level: 7
          }
          conv_hyperparams {
            regularizer {
                l2_regularizer {
                }
              }
              initializer {
                truncated_normal_initializer {
                }
              }
          }
        }
        box_coder {
          faster_rcnn_box_coder {
          }
        }
        matcher {
          argmax_matcher {
          }
        }
        similarity_calculator {
          iou_similarity {
          }
        }
        encode_background_as_zeros: true
        anchor_generator {
          multiscale_anchor_generator {
            aspect_ratios: [1.0, 2.0, 0.5]
            scales_per_octave: 2
          }
        }
        image_resizer {
          fixed_shape_resizer {
            height: 320
            width: 320
          }
        }
        box_predictor {
          weight_shared_convolutional_box_predictor {
            depth: 32
            conv_hyperparams {
              regularizer {
                l2_regularizer {
                }
              }
              initializer {
                random_normal_initializer {
                }
              }
            }
            num_layers_before_predictor: 1
          }
        }
        normalize_loss_by_num_matches: true
        normalize_loc_loss_by_codesize: true
        loss {
          classification_loss {
            weighted_sigmoid_focal {
              alpha: 0.25
              gamma: 2.0
            }
          }
          localization_loss {
            weighted_smooth_l1 {
              delta: 0.1
            }
          }
          classification_weight: 1.0
          localization_weight: 1.0
        }
      }"""
        detection_config = model_pb2.DetectionModel()
        text_format.Merge(config_text, detection_config)

        # The same proto is reused; only the extractor type changes per pass.
        feature_maps = SSD_RESNET_V1_FPN_FEAT_MAPS
        for type_name, expected_cls in feature_maps.items():
            detection_config.ssd.feature_extractor.type = type_name
            built_model = model_builder.build(
                detection_config, is_training=True)
            self.assertIsInstance(built_model, ssd_meta_arch.SSDMetaArch)
            self.assertIsInstance(
                built_model._feature_extractor, expected_cls)
Exemple #10
0
 def test_create_rfcn_resnet_v1_model_from_config(self):
     """Builds an RFCN model for every entry of FRCNN_RESNET_FEAT_MAPS.

     The text proto below (a `faster_rcnn` config using `rfcn_box_predictor`)
     is parsed once; for each (type, class) pair in the map the feature
     extractor type is overwritten, the model is built in training mode, and
     both the meta architecture and the feature extractor class are checked.
     """
     config_text = """
   faster_rcnn {
     num_classes: 3
     image_resizer {
       keep_aspect_ratio_resizer {
         min_dimension: 600
         max_dimension: 1024
       }
     }
     feature_extractor {
       type: 'faster_rcnn_resnet101'
     }
     first_stage_anchor_generator {
       grid_anchor_generator {
         scales: [0.25, 0.5, 1.0, 2.0]
         aspect_ratios: [0.5, 1.0, 2.0]
         height_stride: 16
         width_stride: 16
       }
     }
     first_stage_box_predictor_conv_hyperparams {
       regularizer {
         l2_regularizer {
         }
       }
       initializer {
         truncated_normal_initializer {
         }
       }
     }
     initial_crop_size: 14
     maxpool_kernel_size: 2
     maxpool_stride: 2
     second_stage_box_predictor {
       rfcn_box_predictor {
         conv_hyperparams {
           op: CONV
           regularizer {
             l2_regularizer {
             }
           }
           initializer {
             truncated_normal_initializer {
             }
           }
         }
       }
     }
     second_stage_post_processing {
       batch_non_max_suppression {
         score_threshold: 0.01
         iou_threshold: 0.6
         max_detections_per_class: 100
         max_total_detections: 300
       }
       score_converter: SOFTMAX
     }
   }"""
     detection_config = model_pb2.DetectionModel()
     text_format.Merge(config_text, detection_config)

     # The same proto is reused; only the extractor type changes per pass.
     for type_name, expected_cls in FRCNN_RESNET_FEAT_MAPS.items():
         detection_config.faster_rcnn.feature_extractor.type = type_name
         built_model = model_builder.build(detection_config, is_training=True)
         self.assertIsInstance(built_model, rfcn_meta_arch.RFCNMetaArch)
         self.assertIsInstance(built_model._feature_extractor, expected_cls)
Exemple #11
0
 def test_create_faster_rcnn_inception_v2_model_from_config(self):
     """Builds a faster_rcnn model with the inception v2 feature extractor.

     Parses the text proto below, builds the model in training mode and
     checks that it is a `FasterRCNNMetaArch` whose `_feature_extractor` is a
     `FasterRCNNInceptionV2FeatureExtractor`.
     """
     config_text = """
   faster_rcnn {
     num_classes: 3
     image_resizer {
       keep_aspect_ratio_resizer {
         min_dimension: 600
         max_dimension: 1024
       }
     }
     feature_extractor {
       type: 'faster_rcnn_inception_v2'
     }
     first_stage_anchor_generator {
       grid_anchor_generator {
         scales: [0.25, 0.5, 1.0, 2.0]
         aspect_ratios: [0.5, 1.0, 2.0]
         height_stride: 16
         width_stride: 16
       }
     }
     first_stage_box_predictor_conv_hyperparams {
       regularizer {
         l2_regularizer {
         }
       }
       initializer {
         truncated_normal_initializer {
         }
       }
     }
     initial_crop_size: 14
     maxpool_kernel_size: 2
     maxpool_stride: 2
     second_stage_box_predictor {
       mask_rcnn_box_predictor {
         fc_hyperparams {
           op: FC
           regularizer {
             l2_regularizer {
             }
           }
           initializer {
             truncated_normal_initializer {
             }
           }
         }
       }
     }
     second_stage_post_processing {
       batch_non_max_suppression {
         score_threshold: 0.01
         iou_threshold: 0.6
         max_detections_per_class: 100
         max_total_detections: 300
       }
       score_converter: SOFTMAX
     }
   }"""
     detection_config = model_pb2.DetectionModel()
     text_format.Merge(config_text, detection_config)
     built_model = model_builder.build(detection_config, is_training=True)

     expected_meta_arch = faster_rcnn_meta_arch.FasterRCNNMetaArch
     expected_extractor = frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor
     self.assertIsInstance(built_model, expected_meta_arch)
     self.assertIsInstance(built_model._feature_extractor, expected_extractor)
Exemple #12
0
 def test_create_faster_rcnn_resnet101_with_mask_prediction_enabled(
         self, use_matmul_crop_and_resize):
     """Builds a faster_rcnn model with instance mask prediction turned on.

     Parses the text proto below, sets `use_matmul_crop_and_resize` on the
     `faster_rcnn` config from the test argument, builds the model in
     training mode and checks that `_second_stage_mask_loss_weight` is 3.0,
     the value given by `second_stage_mask_prediction_loss_weight`.

     Args:
       use_matmul_crop_and_resize: Value written to
         `faster_rcnn.use_matmul_crop_and_resize` before building.
     """
     config_text = """
   faster_rcnn {
     num_classes: 3
     image_resizer {
       keep_aspect_ratio_resizer {
         min_dimension: 600
         max_dimension: 1024
       }
     }
     feature_extractor {
       type: 'faster_rcnn_resnet101'
     }
     first_stage_anchor_generator {
       grid_anchor_generator {
         scales: [0.25, 0.5, 1.0, 2.0]
         aspect_ratios: [0.5, 1.0, 2.0]
         height_stride: 16
         width_stride: 16
       }
     }
     first_stage_box_predictor_conv_hyperparams {
       regularizer {
         l2_regularizer {
         }
       }
       initializer {
         truncated_normal_initializer {
         }
       }
     }
     initial_crop_size: 14
     maxpool_kernel_size: 2
     maxpool_stride: 2
     second_stage_box_predictor {
       mask_rcnn_box_predictor {
         fc_hyperparams {
           op: FC
           regularizer {
             l2_regularizer {
             }
           }
           initializer {
             truncated_normal_initializer {
             }
           }
         }
         conv_hyperparams {
           regularizer {
             l2_regularizer {
             }
           }
           initializer {
             truncated_normal_initializer {
             }
           }
         }
         predict_instance_masks: true
       }
     }
     second_stage_mask_prediction_loss_weight: 3.0
     second_stage_post_processing {
       batch_non_max_suppression {
         score_threshold: 0.01
         iou_threshold: 0.6
         max_detections_per_class: 100
         max_total_detections: 300
       }
       score_converter: SOFTMAX
     }
   }"""
     detection_config = model_pb2.DetectionModel()
     text_format.Merge(config_text, detection_config)
     frcnn_config = detection_config.faster_rcnn
     frcnn_config.use_matmul_crop_and_resize = use_matmul_crop_and_resize
     built_model = model_builder.build(detection_config, is_training=True)
     self.assertAlmostEqual(built_model._second_stage_mask_loss_weight, 3.0)
def export_tflite_graph(pipeline_config,
                        trained_checkpoint_prefix,
                        output_dir,
                        add_postprocessing_op,
                        max_detections,
                        max_classes_per_detection,
                        detections_per_class=100,
                        use_regular_nms=False):
    """Exports a tflite compatible graph and anchors for ssd detection model.

    Anchors are written to a tensor and the tflite compatible graph is written
    to output_dir/tflite_graph.pb (binary) and output_dir/tflite_graph.pbtxt
    (text).

    Args:
      pipeline_config: a pipeline.proto object containing the configuration
        for SSD model to export.
      trained_checkpoint_prefix: a file prefix for the checkpoint containing
        the trained parameters of the SSD model.
      output_dir: A directory to write the tflite graph and anchor file to.
        It is created if needed, before the config is validated.
      add_postprocessing_op: If add_postprocessing_op is true: frozen graph
        adds a TFLite_Detection_PostProcess custom op
      max_detections: Maximum number of detections (boxes) to show
      max_classes_per_detection: Number of classes to display per detection
      detections_per_class: In regular NonMaxSuppression, number of anchors
        used for NonMaxSuppression per class
      use_regular_nms: Flag to set postprocessing op to use Regular NMS
        instead of Fast NMS.

    Raises:
      ValueError: if the pipeline config contains a model other than ssd, or
        if the ssd model uses an image resizer other than
        fixed_shape_resizer.
    """
    tf.gfile.MakeDirs(output_dir)
    model_kind = pipeline_config.model.WhichOneof('model')
    if model_kind != 'ssd':
        raise ValueError('Only ssd models are supported in tflite. '
                         'Found {} in config'.format(model_kind))

    ssd_config = pipeline_config.model.ssd
    num_classes = ssd_config.num_classes

    # NOTE(review): the thresholds and scales are wrapped in one-element sets
    # exactly as before, because `append_postprocessing_op` receives them in
    # this form (presumably it unpacks them itself -- confirm before
    # changing these to plain floats).
    nms_config = ssd_config.post_processing.batch_non_max_suppression
    nms_score_threshold = {nms_config.score_threshold}
    nms_iou_threshold = {nms_config.iou_threshold}
    box_coder_config = ssd_config.box_coder.faster_rcnn_box_coder
    scale_values = {
        'y_scale': {box_coder_config.y_scale},
        'x_scale': {box_coder_config.x_scale},
        'h_scale': {box_coder_config.height_scale},
        'w_scale': {box_coder_config.width_scale},
    }

    image_resizer_config = ssd_config.image_resizer
    image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
    if image_resizer != 'fixed_shape_resizer':
        # The trailing space keeps the two literals from running together.
        raise ValueError(
            'Only fixed_shape_resizer '
            'is supported with tflite. Found {}'.format(image_resizer))
    fixed_shape_resizer = image_resizer_config.fixed_shape_resizer
    num_channels = _DEFAULT_NUM_CHANNELS
    if fixed_shape_resizer.convert_to_grayscale:
        num_channels = 1
    shape = [
        1, fixed_shape_resizer.height, fixed_shape_resizer.width, num_channels
    ]

    image = tf.placeholder(tf.float32,
                           shape=shape,
                           name='normalized_input_image_tensor')

    detection_model = model_builder.build(pipeline_config.model,
                                          is_training=False)
    predicted_tensors = detection_model.predict(image, true_image_shapes=None)
    # The score conversion occurs before the post-processing custom op
    _, score_conversion_fn = post_processing_builder.build(
        ssd_config.post_processing)
    class_predictions = score_conversion_fn(
        predicted_tensors['class_predictions_with_background'])

    with tf.name_scope('raw_outputs'):
        # 'raw_outputs/box_encodings': a float32 tensor of shape
        #  [1, num_anchors, 4] containing the encoded box predictions. Note
        #  that these are raw predictions and no Non-Max suppression is
        #  applied on them and no decode center size boxes is applied to them.
        tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
        # 'raw_outputs/class_predictions': a float32 tensor of shape
        #  [1, num_anchors, num_classes] containing the class scores for each
        #  anchor after applying score conversion.
        tf.identity(class_predictions, name='class_predictions')
    # 'anchors': a float32 tensor of shape
    #   [4, num_anchors] containing the anchors as a constant node.
    tf.identity(
        get_const_center_size_encoded_anchors(predicted_tensors['anchors']),
        name='anchors')

    # Add global step to the graph, so we know the training step number when
    # we evaluate the model.
    tf.train.get_or_create_global_step()

    # graph rewriter
    is_quantized = pipeline_config.HasField('graph_rewriter')
    if is_quantized:
        graph_rewriter_fn = graph_rewriter_builder.build(
            pipeline_config.graph_rewriter, is_training=False)
        graph_rewriter_fn()

    if ssd_config.feature_extractor.HasField('fpn'):
        exporter.rewrite_nn_resize_op(is_quantized)

    # freeze the graph
    saver_kwargs = {}
    if pipeline_config.eval_config.use_moving_averages:
        saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        # The temporary file object must stay referenced until freezing is
        # done: NamedTemporaryFile removes the file once it is closed or
        # garbage collected.
        moving_average_checkpoint = tempfile.NamedTemporaryFile()
        exporter.replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            moving_average_checkpoint.name)
        checkpoint_to_use = moving_average_checkpoint.name
    else:
        checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()
    frozen_graph_def = exporter.freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=','.join([
            'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
            'anchors'
        ]),
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        output_graph='',
        initializer_nodes='')

    # Add new operation to do post processing in a custom op (TF Lite only)
    if add_postprocessing_op:
        transformed_graph_def = append_postprocessing_op(
            frozen_graph_def, max_detections, max_classes_per_detection,
            nms_score_threshold, nms_iou_threshold, num_classes, scale_values,
            detections_per_class, use_regular_nms)
    else:
        # Return frozen without adding post-processing custom op
        transformed_graph_def = frozen_graph_def

    # Write the graph twice: serialized binary and human-readable text.
    binary_graph = os.path.join(output_dir, 'tflite_graph.pb')
    with tf.gfile.GFile(binary_graph, 'wb') as f:
        f.write(transformed_graph_def.SerializeToString())
    txt_graph = os.path.join(output_dir, 'tflite_graph.pbtxt')
    with tf.gfile.GFile(txt_graph, 'w') as f:
        f.write(str(transformed_graph_def))
  def test_write_saved_model(self):
    """Round-trips a frozen detection graph through a SavedModel.

    A checkpoint is saved from the mock model, the detection graph is built
    with `model_builder.build` patched to return a `FakeModel`, frozen with
    `exporter.freeze_graph_with_def_protos` and written with
    `exporter.write_saved_model`. The SavedModel is then loaded into a fresh
    graph, fed a pair of identical tf.Examples through the 'serving_default'
    signature, and the detection outputs are compared with the expected
    values.
    """
    tmp_dir = self.get_temp_dir()
    source_checkpoint = os.path.join(tmp_dir, 'model.ckpt')
    self._save_checkpoint_from_mock_model(
        source_checkpoint, use_moving_averages=False)
    export_dir = os.path.join(tmp_dir, 'output')
    saved_model_dir = os.path.join(export_dir, 'saved_model')
    tf.gfile.MakeDirs(export_dir)

    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(
          add_detection_keypoints=True, add_detection_masks=True)
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      pipeline_config.eval_config.use_moving_averages = False
      fake_model = model_builder.build(
          pipeline_config.model, is_training=False)
      outputs, placeholder_tensor = exporter._build_detection_graph(
          input_type='tf_example',
          detection_model=fake_model,
          input_shape=None,
          output_collection_name='inference_op',
          graph_hook_fn=None)
      output_node_names = ','.join(outputs.keys())
      saver_def = tf.train.Saver().as_saver_def()
      frozen_graph_def = exporter.freeze_graph_with_def_protos(
          input_graph_def=tf.get_default_graph().as_graph_def(),
          input_saver_def=saver_def,
          input_checkpoint=source_checkpoint,
          output_node_names=output_node_names,
          restore_op_name='save/restore_all',
          filename_tensor_name='save/Const:0',
          output_graph='',
          clear_devices=True,
          initializer_nodes='')
      exporter.write_saved_model(
          saved_model_path=saved_model_dir,
          frozen_graph_def=frozen_graph_def,
          inputs=placeholder_tensor,
          outputs=outputs)

    # Two copies of the same serialized example form the input batch.
    single_example = self._create_tf_example(
        np.ones((4, 4, 3)).astype(np.uint8))
    example_batch = np.hstack([single_example] * 2)

    output_keys = (
        'detection_boxes',
        'detection_scores',
        'detection_classes',
        'detection_keypoints',
        'detection_masks',
        'num_detections',
    )
    with tf.Graph().as_default() as reloaded_graph:
      with self.test_session(graph=reloaded_graph) as sess:
        meta_graph = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)

        # Tensor names are resolved through the serving signature.
        signature = meta_graph.signature_def['serving_default']
        example_input = reloaded_graph.get_tensor_by_name(
            signature.inputs['inputs'].name)
        fetches = [
            reloaded_graph.get_tensor_by_name(signature.outputs[key].name)
            for key in output_keys
        ]

        (boxes_np, scores_np, classes_np, keypoints_np, masks_np,
         num_detections_np) = sess.run(
             fetches, feed_dict={example_input: example_batch})

        expected_boxes = [[[0.0, 0.0, 0.5, 0.5],
                           [0.5, 0.5, 0.8, 0.8]],
                          [[0.5, 0.5, 1.0, 1.0],
                           [0.0, 0.0, 0.0, 0.0]]]
        expected_scores = [[0.7, 0.6],
                           [0.9, 0.0]]
        expected_classes = [[1, 2],
                            [2, 1]]
        self.assertAllClose(boxes_np, expected_boxes)
        self.assertAllClose(scores_np, expected_scores)
        self.assertAllClose(classes_np, expected_classes)
        self.assertAllClose(keypoints_np, np.arange(48).reshape([2, 2, 6, 2]))
        self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
        self.assertAllClose(num_detections_np, [2, 1])
  def test_write_frozen_graph(self):
    """Freezes a mock detection graph to disk and checks the reloaded outputs.

    A checkpoint is saved from the mock model (with moving averages), a
    `tf_example` detection graph is built around a `FakeModel` (with keypoints
    and masks enabled) and frozen to 'frozen_inference_graph.pb'. The frozen
    graph is loaded back through `_load_inference_graph` and its output
    tensors are compared with the expected values.
    """
    work_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(work_dir, 'model.ckpt')
    self._save_checkpoint_from_mock_model(
        checkpoint_prefix, use_moving_averages=True)
    export_dir = os.path.join(work_dir, 'output')
    tf.gfile.MakeDirs(export_dir)
    frozen_graph_path = os.path.join(export_dir, 'frozen_inference_graph.pb')

    # Everything that builds or freezes the graph runs while
    # model_builder.build is patched to hand back the FakeModel.
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(
          add_detection_keypoints=True, add_detection_masks=True)
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      pipeline_config.eval_config.use_moving_averages = False
      detection_model = model_builder.build(
          pipeline_config.model, is_training=False)
      outputs, _ = exporter._build_detection_graph(
          input_type='tf_example',
          detection_model=detection_model,
          input_shape=None,
          output_collection_name='inference_op',
          graph_hook_fn=None)
      # The Saver has to exist before the graph def is snapshotted so that
      # the 'save/...' nodes named below are part of it.
      input_saver_def = tf.train.Saver().as_saver_def()
      graph_def = tf.get_default_graph().as_graph_def()
      exporter.freeze_graph_with_def_protos(
          input_graph_def=graph_def,
          input_saver_def=input_saver_def,
          input_checkpoint=checkpoint_prefix,
          output_node_names=','.join(outputs.keys()),
          restore_op_name='save/restore_all',
          filename_tensor_name='save/Const:0',
          output_graph=frozen_graph_path,
          clear_devices=True,
          initializer_nodes='')

    inference_graph = self._load_inference_graph(frozen_graph_path)
    # NOTE(review): a single example is fed (leading axis of size 1) while the
    # expected outputs have a leading dimension of 2 -- presumably FakeModel
    # emits fixed-size outputs regardless of the input; confirm.
    encoded_example = self._create_tf_example(
        np.ones((4, 4, 3)).astype(np.uint8))
    tf_example_np = np.expand_dims(encoded_example, axis=0)

    # Output tensor names, in the order the results are unpacked below.
    output_tensor_names = (
        'detection_boxes:0',
        'detection_scores:0',
        'detection_classes:0',
        'detection_keypoints:0',
        'detection_masks:0',
        'num_detections:0',
    )
    expected_boxes = [[[0.0, 0.0, 0.5, 0.5],
                       [0.5, 0.5, 0.8, 0.8]],
                      [[0.5, 0.5, 1.0, 1.0],
                       [0.0, 0.0, 0.0, 0.0]]]
    expected_scores = [[0.7, 0.6],
                       [0.9, 0.0]]
    expected_classes = [[1, 2],
                        [2, 1]]

    with self.test_session(graph=inference_graph) as sess:
      tf_example = inference_graph.get_tensor_by_name('tf_example:0')
      fetches = [
          inference_graph.get_tensor_by_name(tensor_name)
          for tensor_name in output_tensor_names
      ]
      (boxes_np, scores_np, classes_np, keypoints_np, masks_np,
       num_detections_np) = sess.run(
           fetches, feed_dict={tf_example: tf_example_np})

      self.assertAllClose(boxes_np, expected_boxes)
      self.assertAllClose(scores_np, expected_scores)
      self.assertAllClose(classes_np, expected_classes)
      self.assertAllClose(keypoints_np, np.arange(48).reshape([2, 2, 6, 2]))
      self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
      self.assertAllClose(num_detections_np, [2, 1])