示例#1
0
 def test_build_random_crop_pad_image_with_optional_parameters(self):
     """All optional random_crop_pad_image fields reach the kwargs dict.

     Paired min/max fields become range tuples; repeated ratio and color
     fields become tuples.
     """
     config_text = """
 random_crop_pad_image {
   min_object_covered: 0.75
   min_aspect_ratio: 0.75
   max_aspect_ratio: 1.5
   min_area: 0.25
   max_area: 0.875
   overlap_thresh: 0.5
   random_coef: 0.125
   min_padded_size_ratio: 0.5
   min_padded_size_ratio: 0.75
   max_padded_size_ratio: 0.5
   max_padded_size_ratio: 0.75
   pad_color: 0.5
   pad_color: 0.5
   pad_color: 1.0
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {
         'min_object_covered': 0.75,
         'aspect_ratio_range': (0.75, 1.5),
         'area_range': (0.25, 0.875),
         'overlap_thresh': 0.5,
         'random_coef': 0.125,
         'min_padded_size_ratio': (0.5, 0.75),
         'max_padded_size_ratio': (0.5, 0.75),
         'pad_color': (0.5, 0.5, 1.0),
     }
     self.assertEqual(built_fn, preprocessor.random_crop_pad_image)
     self.assertEqual(built_kwargs, expected_kwargs)
    def transform_and_pad_input_data_fn(tensor_dict):
      """Combines transform and pad operation.

      The input dictionary is first run through `transform_input_data`
      (augmentation, model preprocessing, resizing). The result is then
      padded by `pad_input_data_to_static_shapes`.

      Args:
        tensor_dict: the input tensor dictionary to transform and pad.

      Returns:
        A (features, labels) tuple produced by `_get_features_dict` and
        `_get_labels_dict` from the padded dictionary.
      """
      augmentation_steps = [
          preprocessor_builder.build(step_config)
          for step_config in train_config.data_augmentation_options
      ]
      augment_fn = functools.partial(
          augment_input_data, data_augmentation_options=augmentation_steps)

      training_model = model_builder.build(model_config, is_training=True)
      resizer_config = config_util.get_image_resizer_config(model_config)
      resizer_fn = image_resizer_builder.build(resizer_config)
      # Shared by the transform step and the padding step below.
      num_classes = config_util.get_number_of_classes(model_config)

      transform_fn = functools.partial(
          transform_input_data,
          model_preprocess_fn=training_model.preprocess,
          image_resizer_fn=resizer_fn,
          num_classes=num_classes,
          data_augmentation_fn=augment_fn,
          merge_multiple_boxes=train_config.merge_multiple_label_boxes,
          retain_original_image=train_config.retain_original_images,
          use_multiclass_scores=train_config.use_multiclass_scores,
          use_bfloat16=train_config.use_bfloat16)

      transformed = transform_fn(tensor_dict)
      padded = pad_input_data_to_static_shapes(
          tensor_dict=transformed,
          max_num_boxes=train_input_config.max_number_of_boxes,
          num_classes=num_classes,
          spatial_image_shape=config_util.get_spatial_image_size(
              resizer_config))
      return _get_features_dict(padded), _get_labels_dict(padded)
示例#3
0
 def test_build_ssd_random_crop_fixed_aspect_ratio(self):
     """Per-operation fields become lists; aspect_ratio stays a scalar."""
     config_text = """
 ssd_random_crop_fixed_aspect_ratio {
   operations {
     min_object_covered: 0.0
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.0
     random_coef: 0.375
   }
   operations {
     min_object_covered: 0.25
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.25
     random_coef: 0.375
   }
   aspect_ratio: 0.875
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {
         'min_object_covered': [0.0, 0.25],
         'aspect_ratio': 0.875,
         'area_range': [(0.5, 1.0), (0.5, 1.0)],
         'overlap_thresh': [0.0, 0.25],
         'random_coef': [0.375, 0.375],
     }
     self.assertEqual(built_fn,
                      preprocessor.ssd_random_crop_fixed_aspect_ratio)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#4
0
 def test_build_random_rotation90(self):
     """An empty random_rotation90 message builds with no kwargs."""
     config_text = """
 random_rotation90 {}
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     self.assertEqual(built_fn, preprocessor.random_rotation90)
     self.assertEqual(built_kwargs, {})
示例#5
0
 def test_build_rgb_to_gray(self):
     """An empty rgb_to_gray message builds with no kwargs."""
     config_text = """
 rgb_to_gray {}
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     self.assertEqual(built_fn, preprocessor.rgb_to_gray)
     self.assertEqual(built_kwargs, {})
示例#6
0
 def test_build_scale_boxes_to_pixel_coordinates(self):
     """An empty scale_boxes_to_pixel_coordinates message has no kwargs."""
     config_text = """
 scale_boxes_to_pixel_coordinates {}
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     self.assertEqual(built_fn,
                      preprocessor.scale_boxes_to_pixel_coordinates)
     self.assertEqual(built_kwargs, {})
示例#7
0
 def test_build_ssd_random_crop_empty_operations(self):
     """ssd_random_crop without operations builds with no kwargs."""
     config_text = """
 ssd_random_crop {
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     self.assertEqual(built_fn, preprocessor.ssd_random_crop)
     self.assertEqual(built_kwargs, {})
示例#8
0
 def test_build_random_jitter_boxes(self):
     """The jitter ratio is forwarded (compared approximately)."""
     config_text = """
 random_jitter_boxes {
   ratio: 0.1
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'ratio': 0.1}
     self.assertEqual(built_fn, preprocessor.random_jitter_boxes)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#9
0
 def test_build_random_distort_color(self):
     """color_ordering is forwarded unchanged."""
     config_text = """
 random_distort_color {
   color_ordering: 1
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'color_ordering': 1}
     self.assertEqual(built_fn, preprocessor.random_distort_color)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#10
0
 def test_build_random_adjust_hue(self):
     """max_delta is forwarded (compared approximately)."""
     config_text = """
 random_adjust_hue {
   max_delta: 0.01
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'max_delta': 0.01}
     self.assertEqual(built_fn, preprocessor.random_adjust_hue)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#11
0
 def test_build_random_rgb_to_gray(self):
     """probability is forwarded (compared approximately)."""
     config_text = """
 random_rgb_to_gray {
   probability: 0.8
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'probability': 0.8}
     self.assertEqual(built_fn, preprocessor.random_rgb_to_gray)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#12
0
 def test_build_subtract_channel_mean(self):
     """The repeated means field compares equal to a plain list."""
     config_text = """
 subtract_channel_mean {
   means: [1.0, 2.0, 3.0]
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'means': [1.0, 2.0, 3.0]}
     self.assertEqual(built_fn, preprocessor.subtract_channel_mean)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#13
0
 def test_build_random_pixel_value_scale(self):
     """minval/maxval are forwarded (compared approximately)."""
     config_text = """
 random_pixel_value_scale {
   minval: 0.8
   maxval: 1.2
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'minval': 0.8, 'maxval': 1.2}
     self.assertEqual(built_fn, preprocessor.random_pixel_value_scale)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#14
0
 def test_build_random_resize_method(self):
     """target_height/target_width collapse into one target_size pair."""
     config_text = """
 random_resize_method {
   target_height: 75
   target_width: 100
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'target_size': [75, 100]}
     self.assertEqual(built_fn, preprocessor.random_resize_method)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#15
0
 def test_build_random_pad_image(self):
     preprocessor_text_proto = """
 random_pad_image {
 }
 """
     preprocessor_proto = preprocessor_pb2.PreprocessingStep()
     text_format.Merge(preprocessor_text_proto, preprocessor_proto)
     function, args = preprocessor_builder.build(preprocessor_proto)
     self.assertEqual(function, preprocessor.random_pad_image)
     self.assertEqual(args, {
         'min_image_size': None,
         'max_image_size': None,
         'pad_color': None,
     })
示例#16
0
 def test_build_random_adjust_saturation(self):
     """min_delta/max_delta are forwarded (compared approximately)."""
     config_text = """
 random_adjust_saturation {
   min_delta: 0.75
   max_delta: 1.15
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'min_delta': 0.75, 'max_delta': 1.15}
     self.assertEqual(built_fn, preprocessor.random_adjust_saturation)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#17
0
 def test_build_random_crop_to_aspect_ratio(self):
     """aspect_ratio/overlap_thresh are forwarded (compared approximately)."""
     config_text = """
 random_crop_to_aspect_ratio {
   aspect_ratio: 0.85
   overlap_thresh: 0.35
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'aspect_ratio': 0.85, 'overlap_thresh': 0.35}
     self.assertEqual(built_fn, preprocessor.random_crop_to_aspect_ratio)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#18
0
 def test_build_ssd_random_crop_pad(self):
     """Each operation contributes one entry to every per-field list.

     Min/max pairs become range tuples and pad_color_{r,g,b} an RGB tuple.
     """
     config_text = """
 ssd_random_crop_pad {
   operations {
     min_object_covered: 0.0
     min_aspect_ratio: 0.875
     max_aspect_ratio: 1.125
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.0
     random_coef: 0.375
     min_padded_size_ratio: [1.0, 1.0]
     max_padded_size_ratio: [2.0, 2.0]
     pad_color_r: 0.5
     pad_color_g: 0.5
     pad_color_b: 0.5
   }
   operations {
     min_object_covered: 0.25
     min_aspect_ratio: 0.75
     max_aspect_ratio: 1.5
     min_area: 0.5
     max_area: 1.0
     overlap_thresh: 0.25
     random_coef: 0.375
     min_padded_size_ratio: [1.0, 1.0]
     max_padded_size_ratio: [2.0, 2.0]
     pad_color_r: 0.5
     pad_color_g: 0.5
     pad_color_b: 0.5
   }
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {
         'min_object_covered': [0.0, 0.25],
         'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)],
         'area_range': [(0.5, 1.0), (0.5, 1.0)],
         'overlap_thresh': [0.0, 0.25],
         'random_coef': [0.375, 0.375],
         'min_padded_size_ratio': [(1.0, 1.0), (1.0, 1.0)],
         'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)],
         'pad_color': [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)],
     }
     self.assertEqual(built_fn, preprocessor.ssd_random_crop_pad)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#19
0
 def test_build_random_vertical_flip(self):
     """The repeated keypoint permutation is forwarded as a tuple."""
     config_text = """
 random_vertical_flip {
   keypoint_flip_permutation: 1
   keypoint_flip_permutation: 0
   keypoint_flip_permutation: 2
   keypoint_flip_permutation: 3
   keypoint_flip_permutation: 5
   keypoint_flip_permutation: 4
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)}
     self.assertEqual(built_fn, preprocessor.random_vertical_flip)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#20
0
 def test_build_resize_image(self):
     """The BICUBIC enum maps to tf.image.ResizeMethod.BICUBIC."""
     config_text = """
 resize_image {
   new_height: 75
   new_width: 100
   method: BICUBIC
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {
         'new_height': 75,
         'new_width': 100,
         'method': tf.image.ResizeMethod.BICUBIC,
     }
     self.assertEqual(built_fn, preprocessor.resize_image)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#21
0
 def test_build_random_black_patches(self):
     """All three random_black_patches fields are forwarded."""
     config_text = """
 random_black_patches {
   max_black_patches: 20
   probability: 0.95
   size_to_image_ratio: 0.12
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {
         'max_black_patches': 20,
         'probability': 0.95,
         'size_to_image_ratio': 0.12,
     }
     self.assertEqual(built_fn, preprocessor.random_black_patches)
     self.assert_dictionary_close(built_kwargs, expected_kwargs)
示例#22
0
 def test_build_normalize_image(self):
     """The original/target min and max values are forwarded unchanged."""
     config_text = """
 normalize_image {
   original_minval: 0.0
   original_maxval: 255.0
   target_minval: -1.0
   target_maxval: 1.0
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {
         'original_minval': 0.0,
         'original_maxval': 255.0,
         'target_minval': -1.0,
         'target_maxval': 1.0,
     }
     self.assertEqual(built_fn, preprocessor.normalize_image)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#23
0
 def test_build_random_crop_image(self):
     """Paired min/max fields of random_crop_image become range tuples."""
     config_text = """
 random_crop_image {
   min_object_covered: 0.75
   min_aspect_ratio: 0.75
   max_aspect_ratio: 1.5
   min_area: 0.25
   max_area: 0.875
   overlap_thresh: 0.5
   random_coef: 0.125
 }
 """
     step_config = text_format.Merge(config_text,
                                     preprocessor_pb2.PreprocessingStep())
     built_fn, built_kwargs = preprocessor_builder.build(step_config)
     expected_kwargs = {
         'min_object_covered': 0.75,
         'aspect_ratio_range': (0.75, 1.5),
         'area_range': (0.25, 0.875),
         'overlap_thresh': 0.5,
         'random_coef': 0.125,
     }
     self.assertEqual(built_fn, preprocessor.random_crop_image)
     self.assertEqual(built_kwargs, expected_kwargs)
示例#24
0
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

    Builds the input queue, model clones, optimizer and summaries inside a
    fresh graph, then runs `slim.learning.train` on the resulting train op.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
        losses.
      train_config: a train_pb2.TrainConfig protobuf. It is only read, never
        modified.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of worker replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      graph_hook_fn: Optional function that is called after the inference
        graph is built (before optimization). This is helpful to perform
        additional changes to the training graph such as adding FakeQuant
        ops. The function should modify the default graph.

    Raises:
      ValueError: If num_clones != 1 while train_config.sync_replicas is true.
    """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            # Adjacent literals are joined before .format() is applied, so
            # the exception carries a single message string (the previous
            # stray comma produced a 2-tuple of message fragments).
            raise ValueError('In Synchronous SGD mode num_clones must '
                             'be 1. Found num_clones: {}'.format(num_clones))
        # Per-clone batch size; in sync mode it is further divided by
        # train_config.replicas_to_aggregate.
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            # None lets optimize_clones pick up the regularization losses
            # itself; an empty list disables them.
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            # The returned train op yields the loss only after all updates ran.
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name,
                                     model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                  loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
            if not fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, the checkpoint type is derived from
                # from_detection_checkpoint. A local is used so the caller's
                # train_config proto is not mutated.
                fine_tune_checkpoint_type = (
                    'detection' if train_config.from_detection_checkpoint
                    else 'classification')
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            # Restore only the variables actually present in the checkpoint.
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                """Restores the fine-tune checkpoint into `sess`."""
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
示例#25
0
def train(create_tensor_dict_fn_list, create_model_fn, train_config, master,
          task, num_clones, worker_replicas, clone_on_cpu, ps_tasks,
          worker_job_name, is_chief, train_dir):
    """Training function for models fed from several input pipelines.

    One input queue is built per entry of `create_tensor_dict_fn_list`. The
    model clones, optimizer and summaries are assembled in a fresh graph, and
    training runs through `tf.contrib.training.train` with a Scaffold.

    Args:
      create_tensor_dict_fn_list: a list of functions, each creating a tensor
        input dictionary. One input queue is created per function.
      create_model_fn: a function that creates a DetectionModel and generates
        losses.
      train_config: a train_pb2.TrainConfig protobuf. `batch_size` is indexed
        per input function, i.e. it needs one entry per element of
        `create_tensor_dict_fn_list`.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of worker replicas to train with. Also
        used as the replica count of the SyncReplicasOptimizer.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
    """
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        with tf.device(deploy_config.inputs_device()), \
             tf.name_scope('Input'):
            input_queue_list = []
            for i, create_tensor_dict_fn in enumerate(
                    create_tensor_dict_fn_list):
                input_queue_list.append(
                    _create_input_queue(
                        train_config.batch_size[i] // num_clones,
                        create_tensor_dict_fn,
                        train_config.batch_queue_capacity,
                        train_config.num_batch_queue_threads,
                        train_config.prefetch_queue_capacity,
                        data_augmentation_options))

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn)
        # The whole queue list is passed to each clone as a single argument.
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue_list])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()), \
             tf.name_scope('Optimizer'):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                # Use the same replica count given to DeploymentConfig above
                # (previously read from train_config.worker_replicas).
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        # Optionally restore pre-trained weights: only trainable variables
        # under the FeatureExtractor/Convnet/ scope are loaded.
        if train_config.fine_tune_checkpoint:
            all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            restore_vars = [
                var for var in all_vars
                if (var.name.split('/')[0] == 'FeatureExtractor'
                    and var.name.split('/')[1] == 'Convnet')
            ]
            pre_train_saver = tf.train.Saver(restore_vars)

            def load_pretrain(scaffold, sess):
                """Scaffold init_fn; `scaffold` is unused but required."""
                pre_train_saver.restore(sess,
                                        train_config.fine_tune_checkpoint)
        else:
            load_pretrain = None

        with tf.device(deploy_config.optimizer_device()), \
             tf.variable_scope('OptimizeClones'):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = [r'.*bias(?:es)?', r'.*beta']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = tf.contrib.training.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            # The returned train op yields the loss only after all updates ran.
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for (grad, var) in grads_and_vars:
            var_name = var.op.name
            grad_name = 'grad/' + var_name
            global_summaries.add(tf.summary.histogram(grad_name, grad))
            global_summaries.add(tf.summary.histogram(var_name, var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        scaffold = tf.train.Scaffold(init_fn=load_pretrain,
                                     summary_op=summary_op,
                                     saver=saver)

        hooks = []
        # StopAtStepHook raises ValueError when neither num_steps nor
        # last_step is given, so only add it when a step limit is configured;
        # otherwise training runs until stopped externally.
        if train_config.num_steps:
            hooks.append(
                tf.train.StopAtStepHook(num_steps=train_config.num_steps))
        hooks.append(
            profile_session_run_hooks.ProfileAtStepHook(
                at_step=200, checkpoint_dir=train_dir))
        # SyncReplicasOptimizer needs its session hook to set up the sync
        # token queue; without it sync training cannot start.
        if sync_optimizer is not None:
            hooks.append(sync_optimizer.make_session_run_hook(is_chief))

        tf.contrib.training.train(
            train_tensor,
            train_dir,
            master=master,
            is_chief=is_chief,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=None,
            save_checkpoint_secs=train_config.save_checkpoint_secs,
            save_summaries_steps=train_config.save_summaries_steps,
            config=session_config)
示例#26
0
def evaluate(create_input_dict_fn, create_model_fn, eval_config,
             checkpoint_dir, eval_dir,
             repeat_evaluation=True):
  """Evaluates a recognition model using the latest checkpoint(s).

  Builds the model and its prediction tensors, then hands the evaluation loop
  to `eval_util.repeated_checkpoint_run`. That loop restores the latest
  checkpoint in `checkpoint_dir` and runs `_process_batch` per batch. It then
  aggregates the results with the metric function selected by
  `eval_config.metrics_set`.

  Args:
    create_input_dict_fn: function forwarded to `_extract_prediction_tensors`
      to build the input tensor dictionary.
    create_model_fn: zero-argument function returning the model to evaluate.
    eval_config: evaluation config proto. This function reads
      `data_preprocessing_steps`, `ignore_groundtruth`, `eval_with_lexicon`,
      `num_visualizations`, `only_visualize_incorrect`, `metrics_set`,
      `use_moving_averages`, `num_examples`, `eval_interval_secs`,
      `max_evals`, `eval_master` and `save_graph`.
    checkpoint_dir: directory whose latest checkpoint is restored.
    eval_dir: directory for summaries and for visualizations (written under
      `<eval_dir>/vis`). The graph is also saved here when
      `eval_config.save_graph` is set.
    repeat_evaluation: when False, evaluate once. When True, evaluate without
      limit unless `eval_config.ignore_groundtruth` (forces a single
      evaluation) or `eval_config.max_evals` bounds the count.

  Raises:
    ValueError: during result aggregation, if `eval_config.metrics_set` is not
      a key of `EVAL_METRICS_FN_DICT`.
  """
  model = create_model_fn()
  data_preprocessing_steps = [
      preprocessor_builder.build(step)
      for step in eval_config.data_preprocessing_steps]

  tensor_dict = _extract_prediction_tensors(
      model=model,
      create_input_dict_fn=create_input_dict_fn,
      data_preprocessing_steps=data_preprocessing_steps,
      ignore_groundtruth=eval_config.ignore_groundtruth,
      evaluate_with_lexicon=eval_config.eval_with_lexicon)

  def _process_batch(tensor_dict, sess, batch_index, counters, update_op):
    """Evaluates one batch and visualizes it if it is among the first ones.

    Returns the fetched result dict, or {} when the batch raised
    `tf.errors.InvalidArgumentError` and was skipped.
    """
    if batch_index >= eval_config.num_visualizations:
      # 'original_image' is only consumed by the visualization below, so it is
      # not fetched for batches that will not be visualized.
      tensor_dict = {k: v for (k, v) in tensor_dict.items()
                     if k != 'original_image'}
    try:
      # The 'glyph' collection is fetched together with the results; its
      # value is not used here.
      (result_dict, _, _) = sess.run(
          [tensor_dict, update_op, tf.get_collection('glyph')])
      counters['success'] += 1
    except tf.errors.InvalidArgumentError:
      logging.info('Skipping image')
      counters['skipped'] += 1
      return {}
    if batch_index < eval_config.num_visualizations:
      # The global step is only needed for visualization, so it is queried
      # here instead of with an extra session run on every batch.
      global_step = tf.train.global_step(sess, tf.train.get_global_step())
      eval_util.visualize_recognition_results(
          result_dict,
          'Recognition_{}'.format(batch_index),
          global_step,
          summary_dir=eval_dir,
          export_dir=os.path.join(eval_dir, 'vis'),
          summary_writer=summary_writer,
          only_visualize_incorrect=eval_config.only_visualize_incorrect)

    return result_dict

  def _process_aggregated_results(result_lists):
    """Applies the metric function named by `eval_config.metrics_set`."""
    eval_metric_fn_key = eval_config.metrics_set
    if eval_metric_fn_key not in EVAL_METRICS_FN_DICT:
      raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
    return EVAL_METRICS_FN_DICT[eval_metric_fn_key](result_lists)

  variables_to_restore = tf.global_variables()
  global_step = tf.train.get_or_create_global_step()
  # The global step may already exist and therefore already be in the list.
  # tf.train.Saver rejects a var_list containing the same variable twice.
  if not any(var is global_step for var in variables_to_restore):
    variables_to_restore.append(global_step)
  if eval_config.use_moving_averages:
    variable_averages = tf.train.ExponentialMovingAverage(0.0)
    variables_to_restore = variable_averages.variables_to_restore()
  saver = tf.train.Saver(variables_to_restore)

  def _restore_latest_checkpoint(sess):
    """Restores the newest checkpoint found in `checkpoint_dir`."""
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    saver.restore(sess, latest_checkpoint)

  # Evaluation-count policy, in order of precedence.
  if eval_config.ignore_groundtruth:
    max_number_of_evaluations = 1
  elif eval_config.max_evals:
    max_number_of_evaluations = eval_config.max_evals
  elif repeat_evaluation:
    max_number_of_evaluations = None  # Evaluate indefinitely.
  else:
    max_number_of_evaluations = 1

  # `_process_batch` looks up `summary_writer` when it is called. It is
  # presumably only called from inside repeated_checkpoint_run, so creating
  # the writer here is early enough. The `finally` ensures the writer is
  # closed even when evaluation raises.
  summary_writer = tf.summary.FileWriter(eval_dir)
  try:
    eval_util.repeated_checkpoint_run(
        tensor_dict=tensor_dict,
        update_op=tf.no_op(),
        summary_dir=eval_dir,
        aggregated_result_processor=_process_aggregated_results,
        batch_processor=_process_batch,
        checkpoint_dirs=[checkpoint_dir],
        variables_to_restore=None,
        restore_fn=_restore_latest_checkpoint,
        num_batches=eval_config.num_examples,
        eval_interval_secs=eval_config.eval_interval_secs,
        max_number_of_evaluations=max_number_of_evaluations,
        master=eval_config.eval_master,
        save_graph=eval_config.save_graph,
        save_graph_dir=(eval_dir if eval_config.save_graph else ''))
  finally:
    summary_writer.close()
示例#27
0
def train(create_tensor_dict_fn, create_model_fn, train_config, input_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir, save_interval_secs=3600, log_every_n_steps=1000):
    """Training function for detection models.

    Builds a fresh graph with `num_clones` model clones fed from a shared input
    queue. It then sets up the optimizer (optionally synchronous across
    replicas) and gradient post-processing, and runs `slim.learning.train`.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      input_config: a input_reader.InputReader protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of worker replicas to train with. It is also
        the total replica count for SyncReplicasOptimizer when
        `train_config.sync_replicas` is set.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      save_interval_secs: Interval in seconds between checkpoint saves.
      log_every_n_steps: The frequency, in global steps, at which the loss and
        global step are logged.
    """

    # This model instance is built outside the training graph. It is only
    # used to choose the loss function (`is_rbbox`) and to build the
    # fine-tuning restore map; clones are created via `create_model_fn`.
    detection_model = create_model_fn()

    preprocess_input_options = [
        preprocessor_input_builder.build(step)
        for step in input_config.preprocess_input_options]

    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            # Each clone consumes an equal share of the global batch.
            input_queue = _create_input_queue(train_config.batch_size // num_clones,
                                              create_tensor_dict_fn,
                                              train_config.batch_queue_capacity,
                                              train_config.num_batch_queue_threads,
                                              train_config.prefetch_queue_capacity,
                                              data_augmentation_options,
                                              preprocess_input_options)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        # Rotated-box models use a dedicated loss builder.
        if detection_model.is_rbbox:
            model_fn = functools.partial(_create_losses_rbbox,
                                         create_model_fn=create_model_fn)
        else:
            model_fn = functools.partial(_create_losses,
                                         create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(train_config.optimizer,
                                                         global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            # SyncReplicasOptimizer lives under tf.train. The replica total
            # comes from `worker_replicas`, the same count given to
            # DeploymentConfig above, not from a TrainConfig field.
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            # Restrict the map to variables that actually exist in the
            # checkpoint so the restore does not fail on missing ones.
            available_var_map = (variables_helper.
                get_variables_available_in_checkpoint(
                var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                              global_step=global_step)
            update_ops.append(grad_updates)

            # The train op returns the loss only after every update has run.
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly; max_to_keep=None keeps every checkpoint.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            max_to_keep=None,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            log_every_n_steps=log_every_n_steps,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            # A num_steps of 0 means "train without a step limit".
            number_of_steps=(train_config.num_steps or None),
            save_summaries_secs=240,
            save_interval_secs=save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver)
    def _train_input_fn(params=None):
        """Builds the training `features` and `labels` tensor dictionaries.

        Args:
          params: Optional parameter dictionary passed from the estimator. When
            it is truthy, `params['batch_size']` is used as the batch size;
            otherwise `train_config.batch_size` is used.

        Returns:
          features: Dictionary of feature tensors.
            features[fields.InputDataFields.image] is a [batch_size, H, W, C]
              float32 tensor with preprocessed images.
            features[HASH_KEY] is a [batch_size] int32 tensor representing
              unique identifiers for the images.
            features[fields.InputDataFields.true_image_shape] is a
              [batch_size, 3] int32 tensor representing the true image shapes,
              as preprocessed images could be padded.
          labels: Dictionary of groundtruth tensors.
            labels[fields.InputDataFields.num_groundtruth_boxes] is a
              [batch_size] int32 tensor indicating the number of groundtruth
              boxes.
            labels[fields.InputDataFields.groundtruth_boxes] is a
              [batch_size, num_boxes, 4] float32 tensor containing the corners
              of the groundtruth boxes.
            labels[fields.InputDataFields.groundtruth_classes] is a
              [batch_size, num_boxes, num_classes] float32 one-hot tensor of
              classes.
            labels[fields.InputDataFields.groundtruth_weights] is a
              [batch_size, num_boxes] float32 tensor containing groundtruth
              weights for the boxes.
            -- Optional (present only if the input provides them) --
            labels[fields.InputDataFields.groundtruth_instance_masks] is a
              [batch_size, num_boxes, H, W] float32 tensor containing only
              binary values, which represent instance masks for objects.
            labels[fields.InputDataFields.groundtruth_keypoints] is a
              [batch_size, num_boxes, num_keypoints, 2] float32 tensor
              containing keypoints for each box.

        Raises:
          TypeError: if `train_config`, `train_input_config` or `model_config`
            is not of the expected protobuf type.
        """
        # Validate the captured configs before building anything.
        if not isinstance(train_config, train_pb2.TrainConfig):
            raise TypeError('For training mode, the `train_config` must be a '
                            'train_pb2.TrainConfig.')
        if not isinstance(train_input_config, input_reader_pb2.InputReader):
            raise TypeError('The `train_input_config` must be a '
                            'input_reader_pb2.InputReader.')
        if not isinstance(model_config, model_pb2.DetectionModel):
            raise TypeError('The `model_config` must be a '
                            'model_pb2.DetectionModel.')

        input_fields = fields.InputDataFields

        augmentation_steps = [preprocessor_builder.build(step)
                              for step in train_config.data_augmentation_options]
        augment_fn = functools.partial(
            augment_input_data, data_augmentation_options=augmentation_steps)

        detection_model = model_builder.build(model_config, is_training=True)
        resizer_config = config_util.get_image_resizer_config(model_config)
        resizer_fn = image_resizer_builder.build(resizer_config)
        num_classes = config_util.get_number_of_classes(model_config)

        transform_fn = functools.partial(
            transform_input_data,
            model_preprocess_fn=detection_model.preprocess,
            image_resizer_fn=resizer_fn,
            num_classes=num_classes,
            data_augmentation_fn=augment_fn)

        if params:
            batch_size = params['batch_size']
        else:
            batch_size = train_config.batch_size

        dataset = dataset_builder.build(
            train_input_config,
            transform_input_data_fn=transform_fn,
            batch_size=batch_size,
            max_num_boxes=train_config.max_number_of_boxes,
            num_classes=num_classes,
            spatial_image_shape=config_util.get_spatial_image_size(resizer_config))
        iterator = dataset_util.make_initializable_iterator(dataset)
        tensor_dict = iterator.get_next()

        # Stable per-image integer ids derived from the source id strings.
        hashed_ids = tf.string_to_hash_bucket_fast(
            tensor_dict[input_fields.source_id], HASH_BINS)
        features = {
            input_fields.image: tensor_dict[input_fields.image],
            HASH_KEY: tf.cast(hashed_ids, tf.int32),
            input_fields.true_image_shape:
                tensor_dict[input_fields.true_image_shape],
        }

        required_label_keys = (
            input_fields.num_groundtruth_boxes,
            input_fields.groundtruth_boxes,
            input_fields.groundtruth_classes,
            input_fields.groundtruth_weights,
        )
        labels = {key: tensor_dict[key] for key in required_label_keys}

        # Optional groundtruth is forwarded only when the input provides it.
        optional_label_keys = (
            input_fields.groundtruth_keypoints,
            input_fields.groundtruth_instance_masks,
        )
        for key in optional_label_keys:
            if key in tensor_dict:
                labels[key] = tensor_dict[key]

        return features, labels