Beispiel #1
0
    def test_get_configs_from_pipeline_file(self):
        """Round-trips a pipeline proto through disk and checks each config.

        Writes a TrainEvalPipelineConfig to a temporary "pipeline.config"
        file, reads it back with config_util.get_configs_from_pipeline_file,
        and verifies that every entry of the returned dictionary matches the
        corresponding field of the written proto.
        """
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        expected = pipeline_pb2.TrainEvalPipelineConfig()
        expected.model.faster_rcnn.num_classes = 10
        expected.train_config.batch_size = 32
        expected.train_input_reader.label_map_path = "path/to/label_map"
        expected.eval_config.num_examples = 20
        expected.eval_input_reader.add().queue_capacity = 100

        _write_config(expected, config_file)

        loaded = config_util.get_configs_from_pipeline_file(config_file)
        # Pair each proto field with the dictionary key it is exposed under.
        field_and_key = (
            (expected.model, "model"),
            (expected.train_config, "train_config"),
            (expected.train_input_reader, "train_input_config"),
            (expected.eval_config, "eval_config"),
            (expected.eval_input_reader, "eval_input_configs"),
        )
        for proto_field, key in field_and_key:
            self.assertProtoEquals(proto_field, loaded[key])
Beispiel #2
0
    def testNewFocalLossParameters(self):
        """Tests that focal loss alpha and gamma are overridden by hparams.

        Writes an SSD config whose weighted sigmoid focal classification loss
        uses alpha = gamma = 1.0, merges hparams carrying new values, and
        checks that both parameters on the resulting model config changed.
        """
        original_alpha = 1.0
        original_gamma = 1.0
        new_alpha = 0.3
        new_gamma = 2.0
        hparams = tf.contrib.training.HParams(focal_loss_alpha=new_alpha,
                                              focal_loss_gamma=new_gamma)
        pipeline_config_path = os.path.join(self.get_temp_dir(),
                                            "pipeline.config")

        pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
        classification_loss = pipeline_config.model.ssd.loss.classification_loss
        classification_loss.weighted_sigmoid_focal.alpha = original_alpha
        classification_loss.weighted_sigmoid_focal.gamma = original_gamma
        _write_config(pipeline_config, pipeline_config_path)

        configs = config_util.get_configs_from_pipeline_file(
            pipeline_config_path)
        configs = config_util.merge_external_params_with_configs(
            configs, hparams)
        classification_loss = configs["model"].ssd.loss.classification_loss
        self.assertAlmostEqual(
            new_alpha, classification_loss.weighted_sigmoid_focal.alpha)
        self.assertAlmostEqual(
            new_gamma, classification_loss.weighted_sigmoid_focal.gamma)
Beispiel #3
0
    def testUpdateMaskTypeForAllInputConfigs(self):
        """Tests that a `mask_type` override reaches every input reader.

        The override must be applied both to the train input config and to
        each of the two named eval input configs.
        """
        old_type = input_reader_pb2.NUMERICAL_MASKS
        new_type = input_reader_pb2.PNG_MASKS

        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")
        proto = pipeline_pb2.TrainEvalPipelineConfig()
        proto.train_input_reader.mask_type = old_type
        for eval_name in ("eval_1", "eval_2"):
            reader = proto.eval_input_reader.add()
            reader.mask_type = old_type
            reader.name = eval_name
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict={"mask_type": new_type})

        self.assertEqual(merged["train_input_config"].mask_type, new_type)
        for eval_config in merged["eval_input_configs"]:
            self.assertEqual(eval_config.mask_type, new_type)
Beispiel #4
0
 def testKeyValueOverrideBadKey(self):
     """Overriding a non-existent dotted config key must raise ValueError."""
     loaded = self._create_and_load_test_configs(
         pipeline_pb2.TrainEvalPipelineConfig())
     bad_override = {"train_config.no_such_field": 10}
     overrides = tf.contrib.training.HParams(**bad_override)
     with self.assertRaises(ValueError):
         config_util.merge_external_params_with_configs(loaded, overrides)
Beispiel #5
0
    def testCheckAndParseInputConfigKey(self):
        """Tests parsing and validation of input-config override keys.

        Covers a fully qualified eval key, a legacy eval key, a key that is
        not an input-config key, and the ValueError raised for each kind of
        malformed key.
        """
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")
        proto = pipeline_pb2.TrainEvalPipelineConfig()
        for eval_name in ("eval_1", "eval_2"):
            proto.eval_input_reader.add().name = eval_name
        _write_config(proto, config_file)
        configs = config_util.get_configs_from_pipeline_file(config_file)

        # Each entry: (update key, is_valid, key_name, input_name, field_name).
        parse_cases = (
            ("eval_input_configs:eval_2:shuffle",
             True, "eval_input_configs", "eval_2", "shuffle"),
            ("eval_shuffle", True, "eval_input_configs", None, "shuffle"),
            ("label_map_path", False, None, None, "label_map_path"),
        )
        for (update_key, want_valid, want_key, want_input,
             want_field) in parse_cases:
            is_valid, key_name, input_name, field_name = (
                config_util.check_and_parse_input_config_key(
                    configs, update_key))
            if want_valid:
                self.assertTrue(is_valid)
            else:
                self.assertFalse(is_valid)
            self.assertEqual(key_name, want_key)
            self.assertEqual(input_name, want_input)
            self.assertEqual(field_name, want_field)

        # Each entry: (malformed update key, expected error message regex).
        error_cases = (
            ("train_input_config:shuffle",
             "Invalid key format when overriding configs."),
            ("invalid_key_name:train_name:shuffle",
             "Invalid key_name when overriding input config."),
            ("eval_input_configs:unknown_eval_name:shuffle",
             "Invalid input_name when overriding input config."),
            ("eval_input_configs:eval_2:unknown_field_name",
             "Invalid field_name when overriding input config."),
        )
        for bad_key, message_regex in error_cases:
            with self.assertRaisesRegexp(ValueError, message_regex):
                config_util.check_and_parse_input_config_key(
                    configs, bad_key)
Beispiel #6
0
 def testOverwriteBatchSizeWithBadValueType(self):
     """Tests that overwriting with a bad value type causes an exception.

     `train_config.batch_size` is an integer field, so supplying the string
     "10" through hparams must raise a TypeError.
     """
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     pipeline_config.train_config.batch_size = 2
     configs = self._create_and_load_test_configs(pipeline_config)
     # Type should be an integer, but we're passing a string "10".
     hparams = tf.contrib.training.HParams(
         **{"train_config.batch_size": "10"})
     with self.assertRaises(TypeError):
         config_util.merge_external_params_with_configs(configs, hparams)
Beispiel #7
0
 def testOverwriteBatchSizeWithKeyValue(self):
     """A dotted `train_config.batch_size` hparam replaces the batch size."""
     proto = pipeline_pb2.TrainEvalPipelineConfig()
     proto.train_config.batch_size = 2
     loaded = self._create_and_load_test_configs(proto)
     overrides = tf.contrib.training.HParams(
         **{"train_config.batch_size": 10})
     merged = config_util.merge_external_params_with_configs(
         loaded, overrides)
     self.assertEqual(10, merged["train_config"].batch_size)
Beispiel #8
0
    def testGetNumberOfClasses(self):
        """Checks that get_number_of_classes reads the model's num_classes."""
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")
        proto = pipeline_pb2.TrainEvalPipelineConfig()
        proto.model.faster_rcnn.num_classes = 20
        _write_config(proto, config_file)

        model_config = config_util.get_configs_from_pipeline_file(
            config_file)["model"]
        self.assertEqual(20, config_util.get_number_of_classes(model_config))
Beispiel #9
0
    def testUseMovingAverageForEval(self):
        """Tests that `eval_with_moving_averages` enables moving averages."""
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        # Start from the disabled state so the override is observable.
        proto.eval_config.use_moving_averages = False
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict={"eval_with_moving_averages": True})
        self.assertEqual(True, merged["eval_config"].use_moving_averages)
Beispiel #10
0
def main(argv):
    """Exports an SSD model as a TFLite-compatible graph.

    Reads the TrainEvalPipelineConfig text proto at --pipeline_config_path,
    merges --config_override on top of it, and passes the result to
    export_tflite_ssd_graph_lib.export_tflite_graph together with the
    checkpoint prefix, output directory and post-processing flags.

    Args:
      argv: Unused positional command-line arguments.

    Raises:
      ValueError: If any of the required path flags is unset (None).
    """
    del argv  # Unused.
    required_flags = ('output_directory', 'pipeline_config_path',
                      'trained_checkpoint_prefix')
    for flag_name in required_flags:
        flags.mark_flag_as_required(flag_name)
    # mark_flag_as_required only registers a validator, and absl runs
    # validators when flags are parsed. That has already happened by the time
    # main() is invoked, so check explicitly. A missing flag then fails fast
    # with a clear message instead of failing later on a None path.
    missing = [name for name in required_flags
               if getattr(FLAGS, name) is None]
    if missing:
        raise ValueError('Missing required flag(s): %s' %
                         ', '.join('--' + name for name in missing))

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    # Applied after the file contents, so scalar fields given in the override
    # replace the values read from the file.
    text_format.Merge(FLAGS.config_override, pipeline_config)
    export_tflite_ssd_graph_lib.export_tflite_graph(
        pipeline_config, FLAGS.trained_checkpoint_prefix,
        FLAGS.output_directory, FLAGS.add_postprocessing_op,
        FLAGS.max_detections, FLAGS.max_classes_per_detection)
Beispiel #11
0
def main(_):
    """Exports an inference graph for the configured detection model.

    Loads the pipeline config from --pipeline_config_path, applies
    --config_override, parses the optional --input_shape and calls
    exporter.export_inference_graph.
    """
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    text_format.Merge(FLAGS.config_override, pipeline_config)
    if FLAGS.input_shape:
        # Comma-separated dimensions; '-1' marks an unknown dimension (None).
        # Strip whitespace so input like "-1, 300, 300, 3" maps ' -1' to None
        # instead of keeping it as the literal integer -1.
        input_shape = []
        for dim in FLAGS.input_shape.split(','):
            dim = dim.strip()
            input_shape.append(None if dim == '-1' else int(dim))
    else:
        input_shape = None
    exporter.export_inference_graph(FLAGS.input_type, pipeline_config,
                                    FLAGS.trained_checkpoint_prefix,
                                    FLAGS.output_directory, input_shape)
Beispiel #12
0
    def test_save_pipeline_config(self):
        """Saving a pipeline config and reloading it yields an equal proto."""
        original = pipeline_pb2.TrainEvalPipelineConfig()
        original.model.faster_rcnn.num_classes = 10
        original.train_config.batch_size = 32
        original.train_input_reader.label_map_path = "path/to/label_map"
        original.eval_config.num_examples = 20
        original.eval_input_reader.add().queue_capacity = 100

        config_util.save_pipeline_config(original, self.get_temp_dir())
        # The saved file is read back from "pipeline.config" in the temp dir.
        saved_path = os.path.join(self.get_temp_dir(), "pipeline.config")
        reloaded = config_util.get_configs_from_pipeline_file(saved_path)
        rebuilt = config_util.create_pipeline_proto_from_configs(reloaded)

        self.assertEqual(original, rebuilt)
Beispiel #13
0
    def testNewBatchSizeWithClipping(self):
        """Tests that an hparams batch size below 1 is clipped up to 1."""
        hparams = tf.contrib.training.HParams(batch_size=0.5)
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        proto.train_config.batch_size = 2
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file), hparams)
        # A requested batch size of 0.5 is expected to be clipped to 1.
        self.assertEqual(1, merged["train_config"].batch_size)
Beispiel #14
0
    def testNewBatchSize(self):
        """Tests that a `batch_size` hparam replaces the configured value."""
        hparams = tf.contrib.training.HParams(batch_size=16)
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        proto.train_config.batch_size = 2
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file), hparams)
        self.assertEqual(16, merged["train_config"].batch_size)
Beispiel #15
0
 def test_export_tflite_graph_without_moving_averages(self):
     """Exports a tiny SSD graph without moving averages and runs it."""
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     pipeline_config.eval_config.use_moving_averages = False
     ssd = pipeline_config.model.ssd
     resizer = ssd.image_resizer.fixed_shape_resizer
     resizer.height = 10
     resizer.width = 10
     ssd.num_classes = 2
     box_coder = ssd.box_coder.faster_rcnn_box_coder
     box_coder.y_scale = 10.0
     box_coder.x_scale = 10.0
     box_coder.height_scale = 5.0
     box_coder.width_scale = 5.0

     tflite_graph_file = self._export_graph(pipeline_config)
     self.assertTrue(os.path.exists(tflite_graph_file))
     box_encodings_np, class_predictions_np = (
         self._import_graph_and_run_inference(tflite_graph_file))
     self.assertAllClose(box_encodings_np,
                         [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]]])
     self.assertAllClose(class_predictions_np, [[[0.7, 0.6], [0.9, 0.0]]])
Beispiel #16
0
def get_configs_from_pipeline_file(pipeline_config_path):
    """Reads config from a file containing pipeline_pb2.TrainEvalPipelineConfig.

    Args:
      pipeline_config_path: Path to a pipeline_pb2.TrainEvalPipelineConfig
        text proto.

    Returns:
      Dictionary of configuration objects, as built by
      create_configs_from_pipeline_proto(). Keys include `model`,
      `train_config`, `train_input_config`, `eval_config` and
      `eval_input_configs`; values are the corresponding config objects.
    """
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(pipeline_config_path, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, pipeline_config)
    return create_configs_from_pipeline_proto(pipeline_config)
Beispiel #17
0
    def testTrainShuffle(self):
        """Checks that the `train_shuffle` kwarg overrides input shuffling."""
        desired_shuffle = False

        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")
        proto = pipeline_pb2.TrainEvalPipelineConfig()
        # Written as True so that the override to False is observable.
        proto.train_input_reader.shuffle = True
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict={"train_shuffle": desired_shuffle})
        self.assertEqual(desired_shuffle,
                         merged["train_input_config"].shuffle)
Beispiel #18
0
    def test_create_pipeline_proto_from_configs(self):
        """Rebuilding a proto from its loaded configs dictionary is lossless."""
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        source = pipeline_pb2.TrainEvalPipelineConfig()
        source.model.faster_rcnn.num_classes = 10
        source.train_config.batch_size = 32
        source.train_input_reader.label_map_path = "path/to/label_map"
        source.eval_config.num_examples = 20
        source.eval_input_reader.add().queue_capacity = 100
        _write_config(source, config_file)

        rebuilt = config_util.create_pipeline_proto_from_configs(
            config_util.get_configs_from_pipeline_file(config_file))
        self.assertEqual(source, rebuilt)
Beispiel #19
0
    def testMergingKeywordArguments(self):
        """Tests that `train_steps` kwargs are merged like hyperparameters."""
        desired_num_train_steps = 10
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        proto.train_config.num_steps = 100
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict={"train_steps": desired_num_train_steps})
        self.assertEqual(desired_num_train_steps,
                         merged["train_config"].num_steps)
Beispiel #20
0
    def testNewMomentumOptimizerValue(self):
        """Tests that a momentum hparam above 1.0 is clipped to 1.0."""
        hparams = tf.contrib.training.HParams(momentum_optimizer_value=1.1)
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        rms_prop = proto.train_config.optimizer.rms_prop_optimizer
        rms_prop.momentum_optimizer_value = 0.4
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file), hparams)
        merged_rms_prop = merged["train_config"].optimizer.rms_prop_optimizer
        # The requested 1.1 is expected to be clipped to 1.0.
        self.assertAlmostEqual(1.0, merged_rms_prop.momentum_optimizer_value)
Beispiel #21
0
    def testNewTrainInputPathList(self):
        """A list-valued `train_input_path` replaces every train input file."""
        new_train_path = ["another/path/to/data", "yet/another/path/to/data"]
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        proto.train_input_reader.tf_record_input_reader.input_path.extend(
            ["path/to/data"])
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict={"train_input_path": new_train_path})
        merged_reader = merged["train_input_config"].tf_record_input_reader
        self.assertEqual(new_train_path, merged_reader.input_path)
Beispiel #22
0
    def testErrorOverwritingMultipleInputConfig(self):
        """Tests that `eval_shuffle` is rejected when several eval inputs exist.

        `eval_shuffle` does not name a specific eval input config, so with two
        eval input configs the merge is expected to raise ValueError
        (presumably because the target reader is ambiguous).
        """
        original_shuffle = False
        new_shuffle = True
        pipeline_config_path = os.path.join(self.get_temp_dir(),
                                            "pipeline.config")
        pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
        eval_1 = pipeline_config.eval_input_reader.add()
        eval_1.shuffle = original_shuffle
        eval_1.name = "eval_1"
        eval_2 = pipeline_config.eval_input_reader.add()
        eval_2.shuffle = original_shuffle
        eval_2.name = "eval_2"
        _write_config(pipeline_config, pipeline_config_path)

        configs = config_util.get_configs_from_pipeline_file(
            pipeline_config_path)
        override_dict = {"eval_shuffle": new_shuffle}
        with self.assertRaises(ValueError):
            # The result is intentionally not bound: the call must raise.
            config_util.merge_external_params_with_configs(
                configs, kwargs_dict=override_dict)
Beispiel #23
0
    def testOverwriteAllEvalNumEpochs(self):
        """Tests that `eval_num_epochs` updates every eval input config."""
        new_num_epochs = 1
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        for _ in range(2):
            proto.eval_input_reader.add().num_epochs = 10
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict={"eval_num_epochs": new_num_epochs})
        for eval_config in merged["eval_input_configs"]:
            self.assertEqual(new_num_epochs, eval_config.num_epochs)
Beispiel #24
0
 def test_export_tflite_graph_with_softmax_score_conversion(self):
     """Exports a tiny SSD graph using a SOFTMAX score converter and runs it."""
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     pipeline_config.eval_config.use_moving_averages = False
     ssd = pipeline_config.model.ssd
     ssd.post_processing.score_converter = (
         post_processing_pb2.PostProcessing.SOFTMAX)
     resizer = ssd.image_resizer.fixed_shape_resizer
     resizer.height = 10
     resizer.width = 10
     ssd.num_classes = 2
     box_coder = ssd.box_coder.faster_rcnn_box_coder
     box_coder.y_scale = 10.0
     box_coder.x_scale = 10.0
     box_coder.height_scale = 5.0
     box_coder.width_scale = 5.0

     tflite_graph_file = self._export_graph(pipeline_config)
     self.assertTrue(os.path.exists(tflite_graph_file))
     box_encodings_np, class_predictions_np = (
         self._import_graph_and_run_inference(tflite_graph_file))
     self.assertAllClose(box_encodings_np,
                         [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]]])
     self.assertAllClose(class_predictions_np,
                         [[[0.524979, 0.475021], [0.710949, 0.28905]]])
Beispiel #25
0
    def test_create_configs_from_pipeline_proto(self):
        """Splits an in-memory pipeline proto into a configs dictionary.

        Each entry of the dictionary must equal the matching field of the
        source TrainEvalPipelineConfig.
        """
        source = pipeline_pb2.TrainEvalPipelineConfig()
        source.model.faster_rcnn.num_classes = 10
        source.train_config.batch_size = 32
        source.train_input_reader.label_map_path = "path/to/label_map"
        source.eval_config.num_examples = 20
        source.eval_input_reader.add().queue_capacity = 100

        configs = config_util.create_configs_from_pipeline_proto(source)
        # Pair each proto field with the dictionary key it is exposed under.
        field_and_key = (
            (source.model, "model"),
            (source.train_config, "train_config"),
            (source.train_input_reader, "train_input_config"),
            (source.eval_config, "eval_config"),
            (source.eval_input_reader, "eval_input_configs"),
        )
        for proto_field, key in field_and_key:
            self.assertProtoEquals(proto_field, configs[key])
Beispiel #26
0
    def testOverwriteAllEvalSampling(self):
        """Tests that `sample_1_of_n_eval_examples` reaches every eval input."""
        new_num_eval_examples = 10
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        for _ in range(2):
            proto.eval_input_reader.add().sample_1_of_n_examples = 1
        _write_config(proto, config_file)

        overrides = {"sample_1_of_n_eval_examples": new_num_eval_examples}
        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict=overrides)
        for eval_config in merged["eval_input_configs"]:
            self.assertEqual(new_num_eval_examples,
                             eval_config.sample_1_of_n_examples)
Beispiel #27
0
def create_pipeline_proto_from_configs(configs):
    """Assembles a pipeline_pb2.TrainEvalPipelineConfig from a configs dict.

    This is the inverse of create_configs_from_pipeline_proto(): the known
    entries of `configs` are copied back into the matching fields of a fresh
    pipeline proto.

    Args:
      configs: Dictionary of configs. See get_configs_from_pipeline_file().
        The "graph_rewriter_config" entry is optional and copied only when
        present.

    Returns:
      A fully populated pipeline_pb2.TrainEvalPipelineConfig.
    """
    result = pipeline_pb2.TrainEvalPipelineConfig()
    singular_fields = (
        (result.model, "model"),
        (result.train_config, "train_config"),
        (result.train_input_reader, "train_input_config"),
        (result.eval_config, "eval_config"),
    )
    for target_field, key in singular_fields:
        target_field.CopyFrom(configs[key])
    # Eval input readers form a repeated field, so they are appended.
    result.eval_input_reader.extend(configs["eval_input_configs"])
    if "graph_rewriter_config" in configs:
        result.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
    return result
Beispiel #28
0
    def testOverWriteRetainOriginalImages(self):
        """Tests that `retain_original_images_in_eval` overrides are applied.

        The kwarg must replace eval_config.retain_original_images (True in
        the written config) with the desired value (False).
        """
        original_retain_original_images = True
        desired_retain_original_images = False

        pipeline_config_path = os.path.join(self.get_temp_dir(),
                                            "pipeline.config")
        pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
        pipeline_config.eval_config.retain_original_images = (
            original_retain_original_images)
        _write_config(pipeline_config, pipeline_config_path)

        configs = config_util.get_configs_from_pipeline_file(
            pipeline_config_path)
        override_dict = {
            "retain_original_images_in_eval": desired_retain_original_images
        }
        configs = config_util.merge_external_params_with_configs(
            configs, kwargs_dict=override_dict)
        retain_original_images = configs["eval_config"].retain_original_images
        self.assertEqual(desired_retain_original_images,
                         retain_original_images)
Beispiel #29
0
 def test_export_tflite_graph_with_postprocessing_op(self):
     """Exports with the TFLite detection post-processing op and checks it.

     The exported GraphDef must contain a 'TFLite_Detection_PostProcess'
     node. Its attributes must mirror the box coder scales and class count
     configured below, have `_output_quantized` and
     `_support_output_type_float_in_quantized_op` set, and list only
     DT_FLOAT in `_output_types`.
     """
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
     pipeline_config.eval_config.use_moving_averages = False
     pipeline_config.model.ssd.post_processing.score_converter = (
         post_processing_pb2.PostProcessing.SIGMOID)
     pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 10
     pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 10
     pipeline_config.model.ssd.num_classes = 2
     pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.y_scale = 10.0
     pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.x_scale = 10.0
     pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.height_scale = 5.0
     pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.width_scale = 5.0
     tflite_graph_file = self._export_graph_with_postprocessing_op(
         pipeline_config)
     self.assertTrue(os.path.exists(tflite_graph_file))
     graph = tf.Graph()
     with graph.as_default():
         graph_def = tf.GraphDef()
         with tf.gfile.Open(tflite_graph_file, 'rb') as f:
             graph_def.ParseFromString(f.read())
         all_op_names = [node.name for node in graph_def.node]
         # assertIn/assertEqual report the offending values on failure,
         # unlike assertTrue(<comparison>).
         self.assertIn('TFLite_Detection_PostProcess', all_op_names)
         for node in graph_def.node:
             if node.name != 'TFLite_Detection_PostProcess':
                 continue
             self.assertTrue(node.attr['_output_quantized'].b)
             self.assertTrue(
                 node.attr['_support_output_type_float_in_quantized_op'].b)
             self.assertEqual(node.attr['y_scale'].f, 10.0)
             self.assertEqual(node.attr['x_scale'].f, 10.0)
             self.assertEqual(node.attr['h_scale'].f, 5.0)
             self.assertEqual(node.attr['w_scale'].f, 5.0)
             self.assertEqual(node.attr['num_classes'].i, 2)
             self.assertTrue(
                 all(t == types_pb2.DT_FLOAT
                     for t in node.attr['_output_types'].list.type))
Beispiel #30
0
    def testNewMaskType(self):
        """Checks that a `mask_type` override updates train and eval readers."""
        old_type = input_reader_pb2.NUMERICAL_MASKS
        new_type = input_reader_pb2.PNG_MASKS
        config_file = os.path.join(self.get_temp_dir(), "pipeline.config")

        proto = pipeline_pb2.TrainEvalPipelineConfig()
        proto.train_input_reader.mask_type = old_type
        proto.eval_input_reader.add().mask_type = old_type
        _write_config(proto, config_file)

        merged = config_util.merge_external_params_with_configs(
            config_util.get_configs_from_pipeline_file(config_file),
            kwargs_dict={"mask_type": new_type})
        self.assertEqual(new_type, merged["train_input_config"].mask_type)
        self.assertEqual(new_type, merged["eval_input_configs"][0].mask_type)