def test_build_random_crop_pad_image_with_optional_parameters(self):
   """Builds random_crop_pad_image with every optional field populated.

   Min/max aspect ratio and area are paired into range tuples, and the
   repeated min/max_padded_size_ratio entries are expected back as tuples;
   pad_color is unset and so expected to be None.
   """
   config_text = """
   random_crop_pad_image {
     min_object_covered: 0.75
     min_aspect_ratio: 0.75
     max_aspect_ratio: 1.5
     min_area: 0.25
     max_area: 0.875
     overlap_thresh: 0.5
     clip_boxes: False
     random_coef: 0.125
     min_padded_size_ratio: 0.5
     min_padded_size_ratio: 0.75
     max_padded_size_ratio: 0.5
     max_padded_size_ratio: 0.75
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_crop_pad_image)
   expected_args = dict(
       min_object_covered=0.75,
       aspect_ratio_range=(0.75, 1.5),
       area_range=(0.25, 0.875),
       overlap_thresh=0.5,
       clip_boxes=False,
       random_coef=0.125,
       min_padded_size_ratio=(0.5, 0.75),
       max_padded_size_ratio=(0.5, 0.75),
       pad_color=None,
   )
   self.assertEqual(built_args, expected_args)
 def test_build_ssd_random_crop_fixed_aspect_ratio(self):
   """Per-operation fields are gathered into parallel lists.

   The single top-level aspect_ratio stays a scalar, while each operation's
   min/max area is paired into an entry of the area_range list.
   """
   config_text = """
   ssd_random_crop_fixed_aspect_ratio {
     operations {
       min_object_covered: 0.0
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.0
       clip_boxes: False
       random_coef: 0.375
     }
     operations {
       min_object_covered: 0.25
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.25
       clip_boxes: True
       random_coef: 0.375
     }
     aspect_ratio: 0.875
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.ssd_random_crop_fixed_aspect_ratio)
   expected_args = dict(
       min_object_covered=[0.0, 0.25],
       aspect_ratio=0.875,
       area_range=[(0.5, 1.0), (0.5, 1.0)],
       overlap_thresh=[0.0, 0.25],
       clip_boxes=[False, True],
       random_coef=[0.375, 0.375],
   )
   self.assertEqual(built_args, expected_args)
Exemplo n.º 3
0
  def transform_and_pad_input_data_fn(tensor_dict):
    """Combines transform and pad operation.

    Builds the configured augmentation steps, applies `transform_input_data`
    (with augmentation, resizing and model preprocessing) to `tensor_dict`,
    then pads the result to static shapes.

    This is a closure: `train_config`, `train_input_config`, `model_config`
    and `model_preprocess_fn` come from the enclosing scope.

    Args:
      tensor_dict: input tensor dictionary (exact keys depend on the decoder
        in the caller -- not visible here).

    Returns:
      A `(features, labels)` tuple produced by `_get_features_dict` and
      `_get_labels_dict` from the padded tensor dictionary.
    """
    augmentation_steps = [
        preprocessor_builder.build(step_config)
        for step_config in train_config.data_augmentation_options
    ]
    augment_fn = functools.partial(
        augment_input_data, data_augmentation_options=augmentation_steps)

    resizer_config = config_util.get_image_resizer_config(model_config)
    resizer_fn = image_resizer_builder.build(resizer_config)
    transform_fn = functools.partial(
        transform_input_data,
        model_preprocess_fn=model_preprocess_fn,
        image_resizer_fn=resizer_fn,
        num_classes=config_util.get_number_of_classes(model_config),
        data_augmentation_fn=augment_fn,
        merge_multiple_boxes=train_config.merge_multiple_label_boxes,
        retain_original_image=train_config.retain_original_images,
        use_multiclass_scores=train_config.use_multiclass_scores,
        use_bfloat16=train_config.use_bfloat16)

    # Transform first, then pad: padding needs the post-resize spatial shape.
    transformed = transform_fn(tensor_dict)
    padded = pad_input_data_to_static_shapes(
        tensor_dict=transformed,
        max_num_boxes=train_input_config.max_number_of_boxes,
        num_classes=config_util.get_number_of_classes(model_config),
        spatial_image_shape=config_util.get_spatial_image_size(resizer_config))
    features = _get_features_dict(padded)
    labels = _get_labels_dict(padded)
    return (features, labels)
 def test_build_rgb_to_gray(self):
   """An empty rgb_to_gray step maps to preprocessor.rgb_to_gray, no kwargs."""
   config_text = """
   rgb_to_gray {}
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.rgb_to_gray)
   expected_args = {}
   self.assertEqual(built_args, expected_args)
 def test_build_scale_boxes_to_pixel_coordinates(self):
   """An empty scale_boxes_to_pixel_coordinates step builds with no kwargs."""
   config_text = """
   scale_boxes_to_pixel_coordinates {}
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.scale_boxes_to_pixel_coordinates)
   expected_args = {}
   self.assertEqual(built_args, expected_args)
 def test_build_random_rotation90(self):
   """An empty random_rotation90 step builds with no kwargs."""
   config_text = """
   random_rotation90 {}
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_rotation90)
   expected_args = {}
   self.assertEqual(built_args, expected_args)
 def test_build_ssd_random_crop_empty_operations(self):
   """ssd_random_crop with no operations yields an empty kwargs dict."""
   config_text = """
   ssd_random_crop {
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.ssd_random_crop)
   expected_args = {}
   self.assertEqual(built_args, expected_args)
 def test_auto_augment_image(self):
   """autoaugment_image passes its policy_name through unchanged."""
   config_text = """
   autoaugment_image {
     policy_name: 'v0'
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.autoaugment_image)
   expected_args = {'policy_name': 'v0'}
   self.assert_dictionary_close(built_args, expected_args)
 def test_build_subtract_channel_mean(self):
   """The per-channel means list is forwarded as the `means` kwarg."""
   config_text = """
   subtract_channel_mean {
     means: [1.0, 2.0, 3.0]
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.subtract_channel_mean)
   expected_args = {'means': [1.0, 2.0, 3.0]}
   self.assertEqual(built_args, expected_args)
Exemplo n.º 10
0
 def test_build_random_jitter_boxes(self):
   """random_jitter_boxes forwards its jitter ratio."""
   config_text = """
   random_jitter_boxes {
     ratio: 0.1
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_jitter_boxes)
   expected_args = {'ratio': 0.1}
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 11
0
 def test_build_random_distort_color(self):
   """random_distort_color forwards the integer color_ordering."""
   config_text = """
   random_distort_color {
     color_ordering: 1
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_distort_color)
   expected_args = {'color_ordering': 1}
   self.assertEqual(built_args, expected_args)
Exemplo n.º 12
0
 def test_build_random_adjust_hue(self):
   """random_adjust_hue forwards max_delta (compared approximately)."""
   config_text = """
   random_adjust_hue {
     max_delta: 0.01
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_adjust_hue)
   expected_args = {'max_delta': 0.01}
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 13
0
 def test_build_random_rgb_to_gray(self):
   """random_rgb_to_gray forwards its application probability."""
   config_text = """
   random_rgb_to_gray {
     probability: 0.8
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_rgb_to_gray)
   expected_args = {'probability': 0.8}
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 14
0
 def test_build_normalize_image_convert_class_logits_to_softmax(self):
   """convert_class_logits_to_softmax forwards its temperature.

   Despite the method name, this exercises convert_class_logits_to_softmax,
   not normalize_image.
   """
   config_text = """
   convert_class_logits_to_softmax {
       temperature: 2
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.convert_class_logits_to_softmax)
   expected_args = {'temperature': 2}
   self.assertEqual(built_args, expected_args)
Exemplo n.º 15
0
 def test_build_random_adjust_saturation(self):
   """random_adjust_saturation forwards both min_delta and max_delta."""
   config_text = """
   random_adjust_saturation {
     min_delta: 0.75
     max_delta: 1.15
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_adjust_saturation)
   expected_args = {'min_delta': 0.75, 'max_delta': 1.15}
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 16
0
 def test_build_random_resize_method(self):
   """target_height/target_width are merged into a single target_size."""
   config_text = """
   random_resize_method {
     target_height: 75
     target_width: 100
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_resize_method)
   expected_args = {'target_size': [75, 100]}
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 17
0
 def test_build_random_pixel_value_scale(self):
   """random_pixel_value_scale forwards its minval/maxval bounds."""
   config_text = """
   random_pixel_value_scale {
     minval: 0.8
     maxval: 1.2
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_pixel_value_scale)
   expected_args = {'minval': 0.8, 'maxval': 1.2}
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 18
0
 def test_random_self_concat_image(self):
   """Both concat probabilities are forwarded as separate kwargs."""
   config_text = """
   random_self_concat_image {
     concat_vertical_probability: 0.5
     concat_horizontal_probability: 0.25
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_self_concat_image)
   expected_args = dict(concat_vertical_probability=0.5,
                        concat_horizontal_probability=0.25)
   self.assertEqual(built_args, expected_args)
Exemplo n.º 19
0
 def test_build_random_pad_image(self):
   """An empty random_pad_image config leaves all size/color kwargs as None."""
   config_text = """
   random_pad_image {
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_pad_image)
   expected_args = dict.fromkeys(
       ('min_image_size', 'max_image_size', 'pad_color'))
   self.assertEqual(built_args, expected_args)
Exemplo n.º 20
0
 def test_build_random_black_patches(self):
   """random_black_patches forwards count, probability and size ratio."""
   config_text = """
   random_black_patches {
     max_black_patches: 20
     probability: 0.95
     size_to_image_ratio: 0.12
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_black_patches)
   expected_args = dict(max_black_patches=20,
                        probability=0.95,
                        size_to_image_ratio=0.12)
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 21
0
 def test_build_random_crop_to_aspect_ratio(self):
   """random_crop_to_aspect_ratio forwards ratio, threshold and clip flag."""
   config_text = """
   random_crop_to_aspect_ratio {
     aspect_ratio: 0.85
     overlap_thresh: 0.35
     clip_boxes: False
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_crop_to_aspect_ratio)
   expected_args = dict(aspect_ratio=0.85,
                        overlap_thresh=0.35,
                        clip_boxes=False)
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 22
0
 def test_build_resize_image(self):
   """The proto BICUBIC enum maps to tf.image.ResizeMethod.BICUBIC."""
   config_text = """
   resize_image {
     new_height: 75
     new_width: 100
     method: BICUBIC
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.resize_image)
   expected_args = dict(new_height=75,
                        new_width=100,
                        method=tf.image.ResizeMethod.BICUBIC)
   self.assertEqual(built_args, expected_args)
Exemplo n.º 23
0
 def test_drop_label_probabilistically(self):
   """The proto `label` field is surfaced as the `dropped_label` kwarg."""
   config_text = """
   drop_label_probabilistically{
     label: 2
     drop_probability: 0.5
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.drop_label_probabilistically)
   expected_args = dict(dropped_label=2, drop_probability=0.5)
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 24
0
 def test_build_ssd_random_crop_pad(self):
   """Per-operation fields become parallel lists of kwargs.

   Each operation's min/max aspect ratio and area are paired into tuples,
   the padded size ratios become tuples, and pad_color_r/g/b are combined
   into one RGB tuple per operation.
   """
   config_text = """
   ssd_random_crop_pad {
     operations {
       min_object_covered: 0.0
       min_aspect_ratio: 0.875
       max_aspect_ratio: 1.125
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.0
       clip_boxes: False
       random_coef: 0.375
       min_padded_size_ratio: [1.0, 1.0]
       max_padded_size_ratio: [2.0, 2.0]
       pad_color_r: 0.5
       pad_color_g: 0.5
       pad_color_b: 0.5
     }
     operations {
       min_object_covered: 0.25
       min_aspect_ratio: 0.75
       max_aspect_ratio: 1.5
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.25
       clip_boxes: True
       random_coef: 0.375
       min_padded_size_ratio: [1.0, 1.0]
       max_padded_size_ratio: [2.0, 2.0]
       pad_color_r: 0.5
       pad_color_g: 0.5
       pad_color_b: 0.5
     }
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.ssd_random_crop_pad)
   expected_args = dict(
       min_object_covered=[0.0, 0.25],
       aspect_ratio_range=[(0.875, 1.125), (0.75, 1.5)],
       area_range=[(0.5, 1.0), (0.5, 1.0)],
       overlap_thresh=[0.0, 0.25],
       clip_boxes=[False, True],
       random_coef=[0.375, 0.375],
       min_padded_size_ratio=[(1.0, 1.0), (1.0, 1.0)],
       max_padded_size_ratio=[(2.0, 2.0), (2.0, 2.0)],
       pad_color=[(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)],
   )
   self.assertEqual(built_args, expected_args)
Exemplo n.º 25
0
 def test_build_random_vertical_flip(self):
   """Repeated keypoint_flip_permutation entries come back as an ordered tuple."""
   config_text = """
   random_vertical_flip {
     keypoint_flip_permutation: 1
     keypoint_flip_permutation: 0
     keypoint_flip_permutation: 2
     keypoint_flip_permutation: 3
     keypoint_flip_permutation: 5
     keypoint_flip_permutation: 4
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_vertical_flip)
   expected_args = {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)}
   self.assertEqual(built_args, expected_args)
Exemplo n.º 26
0
 def test_build_random_absolute_pad_image(self):
   """Max paddings are forwarded; an unset pad color is expected as None."""
   config_text = """
   random_absolute_pad_image {
     max_height_padding: 50
     max_width_padding: 100
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.random_absolute_pad_image)
   expected_args = dict(max_height_padding=50,
                        max_width_padding=100,
                        pad_color=None)
   self.assertEqual(built_args, expected_args)
Exemplo n.º 27
0
 def test_remap_labels(self):
   """Repeated original_labels are collected into a list next to new_label."""
   config_text = """
   remap_labels{
     original_labels: 1
     original_labels: 2
     new_label: 3
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.remap_labels)
   expected_args = dict(original_labels=[1, 2], new_label=3)
   self.assert_dictionary_close(built_args, expected_args)
Exemplo n.º 28
0
 def test_build_normalize_image(self):
   """normalize_image forwards the original and target value ranges."""
   config_text = """
   normalize_image {
     original_minval: 0.0
     original_maxval: 255.0
     target_minval: -1.0
     target_maxval: 1.0
   }
   """
   step = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(config_text, step)
   built_fn, built_args = preprocessor_builder.build(step)
   self.assertEqual(built_fn, preprocessor.normalize_image)
   expected_args = dict(original_minval=0.0,
                        original_maxval=255.0,
                        target_minval=-1.0,
                        target_maxval=1.0)
   self.assertEqual(built_args, expected_args)
Exemplo n.º 29
0
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
        losses.
      train_config: a train_pb2.TrainConfig protobuf. Note: when
        `fine_tune_checkpoint` is set but `fine_tune_checkpoint_type` is empty,
        this function writes `fine_tune_checkpoint_type` back into this proto.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      graph_hook_fn: Optional function that is called after the clones (and
        hence the training graph) are built, before the optimizer is set up.
        This is helpful to perform additional changes to the training graph
        such as adding FakeQuant ops. The function should modify the default
        graph.

    Raises:
      ValueError: If num_clones != 1 while train_config.sync_replicas is true.
    """

    # This model instance is only used below to build the checkpoint
    # restore map; the clones create their own models via _create_losses.
    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            # One formatted message string (adjacent literals are concatenated
            # at compile time) rather than two separate exception args.
            raise ValueError('In Synchronous SGD mode num_clones must '
                             'be 1. Found num_clones: {}'.format(num_clones))
        # The configured batch size is global; split it across clones and,
        # in sync mode, across the replicas being aggregated.
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set()

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            # Kept separately so slim.learning.train can drive its sync hooks.
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            # None lets optimize_clones collect regularization losses itself;
            # an empty list disables them.
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            # train_tensor evaluates to the loss but forces all updates first.
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name,
                                     model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                  loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, fine_tune_checkpoint_type is set based on
                # from_detection_checkpoint.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.
                fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            # Restrict restoration to variables actually present in the
            # checkpoint (global step excluded).
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            # A num_steps of 0 means "train indefinitely".
            number_of_steps=train_config.num_steps or None,
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)