def test_build_ssd_random_crop_fixed_aspect_ratio(self):
     preprocessor_text_proto = """
 ssd_random_crop_fixed_aspect_ratio {
   operations {
     min_object_covered: 0.0
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.0
     random_coef: 0.375
   }
   operations {
     min_object_covered: 0.25
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.25
     random_coef: 0.375
   }
   aspect_ratio: 0.875
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function,
                      preprocessor.ssd_random_crop_fixed_aspect_ratio)
     self.assertEqual(
         args, {
             'min_object_covered': [0.0, 0.25],
             'aspect_ratio': 0.875,
             'area_range': [(0.5, 1.0), (0.5, 1.0)],
             'overlap_thresh': [0.0, 0.25],
             'random_coef': [0.375, 0.375]
         })

def test_build_random_crop_pad_image_with_optional_parameters(self):
     preprocessor_text_proto = """
 random_crop_pad_image {
   min_object_covered: 0.75
   min_aspect_ratio: 0.75
   max_aspect_ratio: 1.5
   min_area: 0.25
   max_area: 0.875
   overlap_thresh: 0.5
   random_coef: 0.125
   min_padded_size_ratio: 0.5
   min_padded_size_ratio: 0.75
   max_padded_size_ratio: 0.5
   max_padded_size_ratio: 0.75
   pad_color: 0.5
   pad_color: 0.5
   pad_color: 1.0
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_crop_pad_image)
     self.assertEqual(
         args, {
             'min_object_covered': 0.75,
             'aspect_ratio_range': (0.75, 1.5),
             'area_range': (0.25, 0.875),
             'overlap_thresh': 0.5,
             'random_coef': 0.125,
             'min_padded_size_ratio': (0.5, 0.75),
             'max_padded_size_ratio': (0.5, 0.75),
             'pad_color': (0.5, 0.5, 1.0)
         })

def test_build_random_rotation90(self):
     preprocessor_text_proto = """
 random_rotation90 {}
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_rotation90)
     self.assertEqual(args, {})

def test_build_scale_boxes_to_pixel_coordinates(self):
     preprocessor_text_proto = """
 scale_boxes_to_pixel_coordinates {}
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function,
                      preprocessor.scale_boxes_to_pixel_coordinates)
     self.assertEqual(args, {})

def test_build_ssd_random_crop_empty_operations(self):
     preprocessor_text_proto = """
 ssd_random_crop {
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.ssd_random_crop)
     self.assertEqual(args, {})

def test_build_subtract_channel_mean(self):
     preprocessor_text_proto = """
 subtract_channel_mean {
   means: [1.0, 2.0, 3.0]
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.subtract_channel_mean)
     self.assertEqual(args, {'means': [1.0, 2.0, 3.0]})

def test_build_random_jitter_boxes(self):
     preprocessor_text_proto = """
 random_jitter_boxes {
   ratio: 0.1
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_jitter_boxes)
     self.assert_dictionary_close(args, {'ratio': 0.1})

def test_build_random_rgb_to_gray(self):
     preprocessor_text_proto = """
 random_rgb_to_gray {
   probability: 0.8
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_rgb_to_gray)
     self.assert_dictionary_close(args, {'probability': 0.8})

def test_build_random_distort_color(self):
     preprocessor_text_proto = """
 random_distort_color {
   color_ordering: 1
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_distort_color)
     self.assertEqual(args, {'color_ordering': 1})

def test_build_random_adjust_hue(self):
     preprocessor_text_proto = """
 random_adjust_hue {
   max_delta: 0.01
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_adjust_hue)
     self.assert_dictionary_close(args, {'max_delta': 0.01})

def test_build_random_resize_method(self):
     preprocessor_text_proto = """
 random_resize_method {
   target_height: 75
   target_width: 100
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_resize_method)
     self.assert_dictionary_close(args, {'target_size': [75, 100]})

def test_build_random_pixel_value_scale(self):
     preprocessor_text_proto = """
 random_pixel_value_scale {
   minval: 0.8
   maxval: 1.2
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_pixel_value_scale)
     self.assert_dictionary_close(args, {'minval': 0.8, 'maxval': 1.2})

def test_build_random_pad_image(self):
     preprocessor_text_proto = """
 random_pad_image {
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_pad_image)
     self.assertEqual(args, {
         'min_image_size': None,
         'max_image_size': None,
         'pad_color': None,
     })

def test_build_random_crop_to_aspect_ratio(self):
     preprocessor_text_proto = """
 random_crop_to_aspect_ratio {
   aspect_ratio: 0.85
   overlap_thresh: 0.35
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_crop_to_aspect_ratio)
     self.assert_dictionary_close(args, {
         'aspect_ratio': 0.85,
         'overlap_thresh': 0.35
     })

def test_build_random_adjust_saturation(self):
     preprocessor_text_proto = """
 random_adjust_saturation {
   min_delta: 0.75
   max_delta: 1.15
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_adjust_saturation)
     self.assert_dictionary_close(args, {
         'min_delta': 0.75,
         'max_delta': 1.15
     })

def test_build_ssd_random_crop_pad(self):
     preprocessor_text_proto = """
 ssd_random_crop_pad {
   operations {
     min_object_covered: 0.0
     min_aspect_ratio: 0.875
     max_aspect_ratio: 1.125
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.0
     random_coef: 0.375
     min_padded_size_ratio: [1.0, 1.0]
     max_padded_size_ratio: [2.0, 2.0]
     pad_color_r: 0.5
     pad_color_g: 0.5
     pad_color_b: 0.5
   }
   operations {
     min_object_covered: 0.25
     min_aspect_ratio: 0.75
     max_aspect_ratio: 1.5
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.25
     random_coef: 0.375
     min_padded_size_ratio: [1.0, 1.0]
     max_padded_size_ratio: [2.0, 2.0]
     pad_color_r: 0.5
     pad_color_g: 0.5
     pad_color_b: 0.5
   }
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.ssd_random_crop_pad)
     self.assertEqual(
         args, {
             'min_object_covered': [0.0, 0.25],
             'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)],
             'area_range': [(0.5, 1.0), (0.5, 1.0)],
             'overlap_thresh': [0.0, 0.25],
             'random_coef': [0.375, 0.375],
             'min_padded_size_ratio': [(1.0, 1.0), (1.0, 1.0)],
             'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)],
             'pad_color': [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]
         })

def test_build_random_vertical_flip(self):
     preprocessor_text_proto = """
 random_vertical_flip {
   keypoint_flip_permutation: 1
   keypoint_flip_permutation: 0
   keypoint_flip_permutation: 2
   keypoint_flip_permutation: 3
   keypoint_flip_permutation: 5
   keypoint_flip_permutation: 4
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_vertical_flip)
     self.assertEqual(args,
                      {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)})

def test_build_resize_image(self):
     preprocessor_text_proto = """
 resize_image {
   new_height: 75
   new_width: 100
   method: BICUBIC
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.resize_image)
     self.assertEqual(
         args, {
             'new_height': 75,
             'new_width': 100,
             'method': tf.image.ResizeMethod.BICUBIC
         })

def test_build_random_black_patches(self):
     preprocessor_text_proto = """
 random_black_patches {
   max_black_patches: 20
   probability: 0.95
   size_to_image_ratio: 0.12
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_black_patches)
     self.assert_dictionary_close(
         args, {
             'max_black_patches': 20,
             'probability': 0.95,
             'size_to_image_ratio': 0.12
         })

def test_build_normalize_image(self):
     preprocessor_text_proto = """
 normalize_image {
   original_minval: 0.0
   original_maxval: 255.0
   target_minval: -1.0
   target_maxval: 1.0
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.normalize_image)
     self.assertEqual(
         args, {
             'original_minval': 0.0,
             'original_maxval': 255.0,
             'target_minval': -1.0,
             'target_maxval': 1.0,
         })

def test_build_random_crop_image(self):
     preprocessor_text_proto = """
 random_crop_image {
   min_object_covered: 0.75
   min_aspect_ratio: 0.75
   max_aspect_ratio: 1.5
   min_area: 0.25
   max_area: 0.875
   overlap_thresh: 0.5
   random_coef: 0.125
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_crop_image)
     self.assertEqual(
         args, {
             'min_object_covered': 0.75,
             'aspect_ratio_range': (0.75, 1.5),
             'area_range': (0.25, 0.875),
             'overlap_thresh': 0.5,
             'random_coef': 0.125,
         })
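
Several tests above call an assert_dictionary_close helper that is not shown in this excerpt. Below is a minimal sketch of such a helper, assuming it compares dictionaries with approximate equality for float values; the real helper lives on the test class and may differ:

def assert_dictionary_close(self, dict1, dict2):
    # Sketch of the helper assumed by the tests above: exact comparison for
    # keys and non-float values, approximate comparison for floats.
    self.assertEqual(set(dict1.keys()), set(dict2.keys()))
    for key, value in dict1.items():
        if isinstance(value, float):
            self.assertAlmostEqual(value, dict2[key])
        else:
            self.assertEqual(value, dict2[key])
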
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

    detection_model = create_model_fn()
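    # Each configured augmentation step builds to a (function, kwargs) pair,
    # as exercised by the preprocessor_builder tests above.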
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
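            # Each clone dequeues batch_size // num_clones examples per step,
            # so the clones together consume train_config.batch_size examples.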
            input_queue = create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
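        # create_clones replicates model_fn once per clone; variables are
        # shared, while each clone's ops live under its own name scope.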
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
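        # In synchronous training, gradients from all replicas are aggregated
        # before each update; the wrapped optimizer is also handed to
        # slim.learning.train below via sync_optimizer.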
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
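            # Build a map of restorable variables and keep only those actually
            # present in the checkpoint, so the Saver restore cannot fail on
            # names missing from fine_tune_checkpoint.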
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
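                # Evaluating train_tensor runs update_op (gradient application
                # and batch-norm updates) before returning the total loss.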
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
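
For context, a hypothetical way to wire up and call train() for single-machine, single-clone training. The names model_builder, input_reader_builder, model_config, and input_config are illustrative assumptions, not part of this excerpt:

# Sketch only: all names not defined in the excerpt are assumed for
# illustration and would come from the surrounding pipeline config code.
import functools

create_model_fn = functools.partial(model_builder.build,
                                    model_config=model_config,
                                    is_training=True)
create_tensor_dict_fn = functools.partial(input_reader_builder.build,
                                          input_config)
train(create_tensor_dict_fn, create_model_fn, train_config,
      master='', task=0, num_clones=1, worker_replicas=1, clone_on_cpu=False,
      ps_tasks=0, worker_job_name='lonely_worker', is_chief=True,
      train_dir='/tmp/train')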