Example no. 1
    def transform_and_pad_input_data_fn(tensor_dict):
      """Combines transform and pad operation."""
      data_augmentation_options = [
          preprocessor_builder.build(step)
          for step in train_config.data_augmentation_options
      ]
      data_augmentation_fn = functools.partial(
          augment_input_data,
          data_augmentation_options=data_augmentation_options)
      model = model_builder.build(model_config, is_training=True)
      image_resizer_config = config_util.get_image_resizer_config(model_config)
      image_resizer_fn = image_resizer_builder.build(image_resizer_config)
      transform_data_fn = functools.partial(
          transform_input_data, model_preprocess_fn=model.preprocess,
          image_resizer_fn=image_resizer_fn,
          num_classes=config_util.get_number_of_classes(model_config),
          data_augmentation_fn=data_augmentation_fn,
          merge_multiple_boxes=train_config.merge_multiple_label_boxes,
          retain_original_image=train_config.retain_original_images,
          use_bfloat16=train_config.use_bfloat16)

      tensor_dict = pad_input_data_to_static_shapes(
          tensor_dict=transform_data_fn(tensor_dict),
          max_num_boxes=train_input_config.max_number_of_boxes,
          num_classes=config_util.get_number_of_classes(model_config),
          spatial_image_shape=config_util.get_spatial_image_size(
              image_resizer_config))
      return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))
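
# Hedged illustration (not part of the original snippet): the function above
# maps a raw tensor dictionary to a (features, labels) pair. The same pattern
# can be exercised on toy data with tf.data alone; every name below is made up
# for the illustration.
import tensorflow as tf

def _toy_transform_and_split(tensor_dict):
  # Stand-in "transform": cast the image to float32 and split off the label.
  features = {'image': tf.cast(tensor_dict['image'], tf.float32)}
  labels = {'class': tensor_dict['class']}
  return features, labels

_toy_dataset = tf.data.Dataset.from_tensor_slices(
    {'image': tf.zeros([4, 8, 8, 3], dtype=tf.uint8),
     'class': tf.ones([4], dtype=tf.int64)})
_toy_dataset = _toy_dataset.map(_toy_transform_and_split)
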
 def test_build_ssd_random_crop(self):
   preprocessor_text_proto = """
   ssd_random_crop {
     operations {
       min_object_covered: 0.0
       min_aspect_ratio: 0.875
       max_aspect_ratio: 1.125
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.0
       clip_boxes: False
       random_coef: 0.375
     }
     operations {
       min_object_covered: 0.25
       min_aspect_ratio: 0.75
       max_aspect_ratio: 1.5
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.25
       clip_boxes: True
       random_coef: 0.375
     }
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.ssd_random_crop)
   self.assertEqual(args, {'min_object_covered': [0.0, 0.25],
                           'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)],
                           'area_range': [(0.5, 1.0), (0.5, 1.0)],
                           'overlap_thresh': [0.0, 0.25],
                           'clip_boxes': [False, True],
                           'random_coef': [0.375, 0.375]})
 def test_build_random_crop_pad_image(self):
   preprocessor_text_proto = """
   random_crop_pad_image {
     min_object_covered: 0.75
     min_aspect_ratio: 0.75
     max_aspect_ratio: 1.5
     min_area: 0.25
     max_area: 0.875
     overlap_thresh: 0.5
     clip_boxes: False
     random_coef: 0.125
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_crop_pad_image)
   self.assertEqual(args, {
       'min_object_covered': 0.75,
       'aspect_ratio_range': (0.75, 1.5),
       'area_range': (0.25, 0.875),
       'overlap_thresh': 0.5,
       'clip_boxes': False,
       'random_coef': 0.125,
       'pad_color': None,
   })
 def test_build_random_crop_pad_image_with_optional_parameters(self):
   preprocessor_text_proto = """
   random_crop_pad_image {
     min_object_covered: 0.75
     min_aspect_ratio: 0.75
     max_aspect_ratio: 1.5
     min_area: 0.25
     max_area: 0.875
     overlap_thresh: 0.5
     random_coef: 0.125
     min_padded_size_ratio: 0.5
     min_padded_size_ratio: 0.75
     max_padded_size_ratio: 0.5
     max_padded_size_ratio: 0.75
     pad_color: 0.5
     pad_color: 0.5
     pad_color: 1.0
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_crop_pad_image)
   self.assertEqual(args, {
       'min_object_covered': 0.75,
       'aspect_ratio_range': (0.75, 1.5),
       'area_range': (0.25, 0.875),
       'overlap_thresh': 0.5,
       'random_coef': 0.125,
       'min_padded_size_ratio': (0.5, 0.75),
       'max_padded_size_ratio': (0.5, 0.75),
       'pad_color': (0.5, 0.5, 1.0)
   })
 def test_build_ssd_random_crop_fixed_aspect_ratio(self):
   preprocessor_text_proto = """
   ssd_random_crop_fixed_aspect_ratio {
     operations {
       min_object_covered: 0.0
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.0
       random_coef: 0.375
     }
     operations {
       min_object_covered: 0.25
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.25
       random_coef: 0.375
     }
     aspect_ratio: 0.875
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.ssd_random_crop_fixed_aspect_ratio)
   self.assertEqual(args, {'min_object_covered': [0.0, 0.25],
                           'aspect_ratio': 0.875,
                           'area_range': [(0.5, 1.0), (0.5, 1.0)],
                           'overlap_thresh': [0.0, 0.25],
                           'random_coef': [0.375, 0.375]})
 def test_build_random_rotation90(self):
   preprocessor_text_proto = """
   random_rotation90 {}
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_rotation90)
   self.assertEqual(args, {})
 def test_build_scale_boxes_to_pixel_coordinates(self):
   preprocessor_text_proto = """
   scale_boxes_to_pixel_coordinates {}
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.scale_boxes_to_pixel_coordinates)
   self.assertEqual(args, {})
 def test_build_rgb_to_gray(self):
   preprocessor_text_proto = """
   rgb_to_gray {}
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.rgb_to_gray)
   self.assertEqual(args, {})
 def test_build_ssd_random_crop_empty_operations(self):
   preprocessor_text_proto = """
   ssd_random_crop {
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.ssd_random_crop)
   self.assertEqual(args, {})
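Each test above verifies that preprocessor_builder.build turns a PreprocessingStep text proto into a (function, kwargs) pair; repeated `operations` blocks collapse into parallel lists of kwargs. Below is a hedged sketch of how such pairs are usually consumed; the final preprocessor.preprocess call is an assumption and is left commented out because it needs a real tensor dictionary.

from google.protobuf import text_format
from object_detection.builders import preprocessor_builder
from object_detection.core import preprocessor
from object_detection.protos import preprocessor_pb2

step_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge('random_rotation90 {}', step_proto)
function, args = preprocessor_builder.build(step_proto)
# The (function, args) pairs are then applied to an input tensor dict, e.g.:
# tensor_dict = preprocessor.preprocess(tensor_dict, [(function, args)])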
Example no. 10
 def test_build_random_distort_color(self):
   preprocessor_text_proto = """
   random_distort_color {
     color_ordering: 1
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_distort_color)
   self.assertEqual(args, {'color_ordering': 1})
Example no. 11
 def test_build_random_jitter_boxes(self):
   preprocessor_text_proto = """
   random_jitter_boxes {
     ratio: 0.1
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_jitter_boxes)
   self.assert_dictionary_close(args, {'ratio': 0.1})
Example no. 12
 def test_build_random_rgb_to_gray(self):
   preprocessor_text_proto = """
   random_rgb_to_gray {
     probability: 0.8
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_rgb_to_gray)
   self.assert_dictionary_close(args, {'probability': 0.8})
Example no. 13
 def test_build_subtract_channel_mean(self):
   preprocessor_text_proto = """
   subtract_channel_mean {
     means: [1.0, 2.0, 3.0]
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.subtract_channel_mean)
   self.assertEqual(args, {'means': [1.0, 2.0, 3.0]})
 def test_build_normalize_image_convert_class_logits_to_softmax(self):
   preprocessor_text_proto = """
   convert_class_logits_to_softmax {
       temperature: 2
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.convert_class_logits_to_softmax)
   self.assertEqual(args, {'temperature': 2})
Example no. 15
 def test_build_random_adjust_hue(self):
   preprocessor_text_proto = """
   random_adjust_hue {
     max_delta: 0.01
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_adjust_hue)
   self.assert_dictionary_close(args, {'max_delta': 0.01})
Example no. 16
 def test_build_random_pixel_value_scale(self):
   preprocessor_text_proto = """
   random_pixel_value_scale {
     minval: 0.8
     maxval: 1.2
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_pixel_value_scale)
   self.assert_dictionary_close(args, {'minval': 0.8, 'maxval': 1.2})
Example no. 17
 def get_next(config, model_config, lstm_config, unroll_length):
   data_augmentation_options = [
       preprocessor_builder.build(step)
       for step in train_config.data_augmentation_options
   ]
   return seq_dataset_builder.build(
       config,
       model_config,
       lstm_config,
       unroll_length,
       data_augmentation_options,
       batch_size=train_config.batch_size)
Example no. 18
 def test_build_random_adjust_saturation(self):
   preprocessor_text_proto = """
   random_adjust_saturation {
     min_delta: 0.75
     max_delta: 1.15
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_adjust_saturation)
   self.assert_dictionary_close(args, {'min_delta': 0.75, 'max_delta': 1.15})
Example no. 19
 def test_build_random_resize_method(self):
   preprocessor_text_proto = """
   random_resize_method {
     target_height: 75
     target_width: 100
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_resize_method)
   self.assert_dictionary_close(args, {'target_size': [75, 100]})
Example no. 20
 def test_build_random_crop_to_aspect_ratio(self):
   preprocessor_text_proto = """
   random_crop_to_aspect_ratio {
     aspect_ratio: 0.85
     overlap_thresh: 0.35
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_crop_to_aspect_ratio)
   self.assert_dictionary_close(args, {'aspect_ratio': 0.85,
                                       'overlap_thresh': 0.35})
 def test_random_self_concat_image(self):
   preprocessor_text_proto = """
   random_self_concat_image {
     concat_vertical_probability: 0.5
     concat_horizontal_probability: 0.25
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_self_concat_image)
   self.assertEqual(args, {'concat_vertical_probability': 0.5,
                           'concat_horizontal_probability': 0.25})
Example no. 22
 def test_build_random_pad_image(self):
   preprocessor_text_proto = """
   random_pad_image {
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_pad_image)
   self.assertEqual(args, {
       'min_image_size': None,
       'max_image_size': None,
       'pad_color': None,
   })
Example no. 23
 def test_build_resize_image(self):
   preprocessor_text_proto = """
   resize_image {
     new_height: 75
     new_width: 100
     method: BICUBIC
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.resize_image)
   self.assertEqual(args, {'new_height': 75,
                           'new_width': 100,
                           'method': tf.image.ResizeMethod.BICUBIC})
Example no. 24
 def test_build_random_black_patches(self):
   preprocessor_text_proto = """
   random_black_patches {
     max_black_patches: 20
     probability: 0.95
     size_to_image_ratio: 0.12
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_black_patches)
   self.assert_dictionary_close(args, {'max_black_patches': 20,
                                       'probability': 0.95,
                                       'size_to_image_ratio': 0.12})
 def test_build_random_absolute_pad_image(self):
   preprocessor_text_proto = """
   random_absolute_pad_image {
     max_height_padding: 50
     max_width_padding: 100
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_absolute_pad_image)
   self.assertEqual(args, {
       'max_height_padding': 50,
       'max_width_padding': 100,
       'pad_color': None,
   })
Example no. 26
 def test_build_random_vertical_flip(self):
   preprocessor_text_proto = """
   random_vertical_flip {
     keypoint_flip_permutation: 1
     keypoint_flip_permutation: 0
     keypoint_flip_permutation: 2
     keypoint_flip_permutation: 3
     keypoint_flip_permutation: 5
     keypoint_flip_permutation: 4
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.random_vertical_flip)
   self.assertEqual(args, {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)})
Example no. 27
 def test_build_ssd_random_crop_pad(self):
   preprocessor_text_proto = """
   ssd_random_crop_pad {
     operations {
       min_object_covered: 0.0
       min_aspect_ratio: 0.875
       max_aspect_ratio: 1.125
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.0
       random_coef: 0.375
       min_padded_size_ratio: [1.0, 1.0]
       max_padded_size_ratio: [2.0, 2.0]
       pad_color_r: 0.5
       pad_color_g: 0.5
       pad_color_b: 0.5
     }
     operations {
       min_object_covered: 0.25
       min_aspect_ratio: 0.75
       max_aspect_ratio: 1.5
       min_area: 0.5
       max_area: 1.0
       overlap_thresh: 0.25
       random_coef: 0.375
       min_padded_size_ratio: [1.0, 1.0]
       max_padded_size_ratio: [2.0, 2.0]
       pad_color_r: 0.5
       pad_color_g: 0.5
       pad_color_b: 0.5
     }
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.ssd_random_crop_pad)
   self.assertEqual(args, {'min_object_covered': [0.0, 0.25],
                           'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)],
                           'area_range': [(0.5, 1.0), (0.5, 1.0)],
                           'overlap_thresh': [0.0, 0.25],
                           'random_coef': [0.375, 0.375],
                           'min_padded_size_ratio': [(1.0, 1.0), (1.0, 1.0)],
                           'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)],
                           'pad_color': [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]})
Example no. 28
 def test_build_normalize_image(self):
   preprocessor_text_proto = """
   normalize_image {
     original_minval: 0.0
     original_maxval: 255.0
     target_minval: -1.0
     target_maxval: 1.0
   }
   """
   preprocessor_proto = preprocessor_pb2.PreprocessingStep()
   text_format.Merge(preprocessor_text_proto, preprocessor_proto)
   function, args = preprocessor_builder.build(preprocessor_proto)
   self.assertEqual(function, preprocessor.normalize_image)
   self.assertEqual(args, {
       'original_minval': 0.0,
       'original_maxval': 255.0,
       'target_minval': -1.0,
       'target_maxval': 1.0,
   })
  def test_build_with_data_augmentation(self):
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(
        self._get_input_proto('tf_record_video_input_reader'),
        input_reader_proto)

    configs = self._get_model_configs_from_proto()
    data_augmentation_options = [
        preprocessor_builder.build(
            self._get_data_augmentation_preprocessor_proto())
    ]
    tensor_dict = seq_dataset_builder.build(
        input_reader_proto,
        configs['model'],
        configs['lstm_model'],
        unroll_length=1,
        data_augmentation_options=data_augmentation_options)

    all_dict = self._create_training_dict(tensor_dict)
    self.assertEqual((1, 32, 32, 3), all_dict['image0'].shape)
    self.assertEqual(4, all_dict['groundtruth_boxes0'].shape[1])
def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)
  
  #iterator = dataset_util.make_initializable_iterator(dataset_builder.build(input_config))
  datasetmy = dataset_builder.build(input_config)
  iterator = datasetmy.make_initializable_iterator()
  
  def get_next(config):
    return iterator.get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)

  
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]
  
  input_queue = trainer.create_input_queue(
      train_config.batch_size, create_input_dict_fn,
      train_config.batch_queue_capacity,
      train_config.num_batch_queue_threads,
      train_config.prefetch_queue_capacity, data_augmentation_options)
  
  tensors = input_queue.dequeue()

  # Print all tensors parsed from the tfrecord.
  print(tensors)
  
  groundtruth_difficult = tensors[0]['groundtruth_difficult']
  groundtruth_group_of = tensors[0]['groundtruth_group_of']
  groundtruth_weights = tensors[0]['groundtruth_weights']
  groundtruth_is_crowd = tensors[0]['groundtruth_is_crowd']
  key = tensors[0]['key']
  groundtruth_boxes = tensors[0]['groundtruth_boxes']
  image = tensors[0]['image']
  groundtruth_area = tensors[0]['groundtruth_area']
  groundtruth_classes = tensors[0]['groundtruth_classes']
  filename = tensors[0]['filename']
  num_groundtruth_boxes = tensors[0]['num_groundtruth_boxes']
  source_id = tensors[0]['source_id']
  
  
  
   
  init_op = tf.global_variables_initializer()  # initialize_all_variables() is deprecated.
  with tf.Session() as sess:
    sess.run(iterator.initializer)
    sess.run(tf.tables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    sess.run(init_op)
    for i in range(10):
      (groundtruth_weights_val, groundtruth_difficult_val, groundtruth_group_of_val,
       groundtruth_is_crowd_val, key_val, groundtruth_boxes_val, image_val,
       groundtruth_area_val, groundtruth_classes_val, filename_val,
       num_groundtruth_boxes_val, source_id_val) = sess.run(
           [groundtruth_weights, groundtruth_difficult, groundtruth_group_of,
            groundtruth_is_crowd, key, groundtruth_boxes, image,
            groundtruth_area, groundtruth_classes, filename,
            num_groundtruth_boxes, source_id])
#       print(groundtruth_weights_val)
      print(groundtruth_boxes_val)
#       print(groundtruth_difficult_val)
#       print(groundtruth_group_of_val)
#       print(groundtruth_is_crowd_val)
#       print(key_val)
#       print(image_val)
#       print(groundtruth_area_val)
      print(groundtruth_classes_val)
      print(filename_val)
      print(num_groundtruth_boxes_val)
#       print(source_id_val)
      image_val = image_val[0]
      image_val = image_val.astype(np.uint8)
#       cv2.imshow('image', image_val)
#       cv2.waitKey()
#       plt.imshow(image_val)
#       plt.show()  
      print('finish')
      
      # Plot bounding boxes on the image.
      plt.switch_backend("TkAgg")
      classes_val = groundtruth_classes_val
      boxes_val = groundtruth_boxes_val
      scores_val = [1.0]*num_groundtruth_boxes_val
      image_np = image_val
      image_np_origin = image_val.copy()
      NUM_CLASSES = 90
      IMAGE_SIZE = (12, 8)
      PATH_TO_LABELS = '../../data/mscoco_label_map.pbtxt'
      label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
      categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                                  use_display_name=True)
      category_index = label_map_util.create_category_index(categories)
      vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                boxes_val,
                np.squeeze(classes_val).astype(np.int32),
                np.squeeze(scores_val),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
      plt.figure(figsize=IMAGE_SIZE)
      plt.subplot(121)
      plt.imshow(image_np)
      plt.subplot(122)
      plt.imshow(image_np_origin)
      plt.show()  
      print('finish')
           
           
      pass
  coord.request_stop()
  coord.join(threads)
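The snippet above only defines main(_); the entry point is not shown. Under TF 1.x it would typically be invoked via tf.app.run (an assumption, not part of the original):

if __name__ == '__main__':
  tf.app.run()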
Example no. 31
def train(datasets_dicts,
          epochs,
          val_every,
          iters_cnt,
          validate_with_eval_model,
          pipeline_config,
          num_clones=1,
          save_cback=None):
    logger.info('Start train')
    configs = configs_from_pipeline(pipeline_config)

    model_config = configs['model']
    train_config = configs['train_config']

    create_model_fn = functools.partial(
        model_builder.build,
        model_config=model_config,
        is_training=True)
    detection_model = create_model_fn()

    def get_next(dataset):
        return dataset_util.make_initializable_iterator(
            build_dataset(dataset)).get_next()

    create_tensor_dict_fn = functools.partial(get_next, datasets_dicts['train'])
    create_tensor_dict_fn_val = functools.partial(get_next, datasets_dicts['val'])

    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=4,
            clone_on_cpu=False,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0,
            worker_job_name='lonely_worker')

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            coord = coordinator.Coordinator()
            input_queue = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity, data_augmentation_options)

            input_queue_val = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn_val,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity, data_augmentation_options)

        # create validation graph
        create_model_fn_val = functools.partial(
            model_builder.build,
            model_config=model_config,
            is_training=not validate_with_eval_model)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        train_losses = []
        grads_and_vars = []
        with slim.arg_scope([slim.model_variable, slim.variable], device='/device:CPU:0'):
            for curr_dev_id in range(num_clones):
                with tf.device('/gpu:{}'.format(curr_dev_id)):
                    with tf.name_scope('clone_{}'.format(curr_dev_id)) as scope:
                        with tf.variable_scope(tf.get_variable_scope(),
                                               reuse=True if curr_dev_id > 0 else None):
                            losses = _create_losses_val(input_queue, create_model_fn, train_config)
                            clones_loss = tf.add_n(losses)
                            clones_loss = tf.divide(clones_loss, 1.0 * num_clones)
                            grads = training_optimizer.compute_gradients(clones_loss)
                            train_losses.append(clones_loss)
                            grads_and_vars.append(grads)
                            if curr_dev_id == 0:
                                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        val_total_loss = get_val_loss(num_clones, input_queue_val, create_model_fn_val, train_config)

        with tf.device(deploy_config.optimizer_device()):
            total_loss = tf.add_n(train_losses)
            grads_and_vars = model_deploy._sum_clones_gradients(grads_and_vars)
            total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                              global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=config)
        saver = tf.train.Saver()

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                        variables.local_variables_initializer(),
                        lookup_ops.tables_initializer())

        # graph.finalize()
        sess.run([init_op, ready_op, local_init_op])

        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))

        logger.info('Start restore')
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
                            load_all_detection_checkpoint_vars=(
                                train_config.load_all_detection_checkpoint_vars))
            available_var_map = (variables_helper.
                                    get_variables_available_in_checkpoint(
                                    var_map, train_config.fine_tune_checkpoint))
            if 'global_step' in available_var_map:
                del available_var_map['global_step']
            init_saver = tf.train.Saver(available_var_map)
            logger.info('Restoring model weights from previous checkpoint.')
            init_saver.restore(sess, train_config.fine_tune_checkpoint)
            logger.info('Model restored.')

        eval_planner = EvalPlanner(epochs, val_every)
        progress = sly.progress_counter_train(epochs, iters_cnt['train'])
        best_val_loss = float('inf')
        epoch_flt = 0

        for epoch in range(epochs):
            logger.info("Before new epoch", extra={'epoch': epoch_flt})
            for train_it in range(iters_cnt['train']):
                total_loss, np_global_step = sess.run([train_tensor, global_step])

                metrics_values_train = {
                    'loss': total_loss,
                }

                progress.iter_done_report()
                epoch_flt = epoch_float(epoch, train_it + 1, iters_cnt['train'])
                sly.report_metrics_training(epoch_flt, metrics_values_train)

                if eval_planner.need_validation(epoch_flt):
                    logger.info("Before validation", extra={'epoch': epoch_flt})

                    overall_val_loss = 0
                    for val_it in range(iters_cnt['val']):
                        overall_val_loss += sess.run(val_total_loss)

                        logger.info("Validation in progress", extra={'epoch': epoch_flt,
                                                                     'val_iter': val_it,
                                                                     'val_iters': iters_cnt['val']})

                    metrics_values_val = {
                        'loss': overall_val_loss / iters_cnt['val'],
                    }
                    sly.report_metrics_validation(epoch_flt, metrics_values_val)
                    logger.info("Validation has been finished", extra={'epoch': epoch_flt})

                    eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        logger.info('It\'s been determined that current model is the best one for a while.')

                    save_cback(saver,
                               sess,
                               model_is_best,
                               opt_data={
                                         'epoch': epoch_flt,
                                         'val_metrics': metrics_values_val,
                               })

            logger.info("Epoch was finished", extra={'epoch': epoch_flt})
        coord.request_stop()
        coord.join(threads)
Example no. 32
    def _train_input_fn(params=None):
        """Returns `features` and `labels` tensor dictionaries for training.

    Args:
      params: Parameter dictionary passed from the estimator.

    Returns:
      features: Dictionary of feature tensors.
        features[fields.InputDataFields.image] is a [batch_size, H, W, C]
          float32 tensor with preprocessed images.
        features[HASH_KEY] is a [batch_size] int32 tensor representing unique
          identifiers for the images.
        features[fields.InputDataFields.true_image_shape] is a [batch_size, 3]
          int32 tensor representing the true image shapes, as preprocessed
          images could be padded.
      labels: Dictionary of groundtruth tensors.
        labels[fields.InputDataFields.num_groundtruth_boxes] is a [batch_size]
          int32 tensor indicating the number of groundtruth boxes.
        labels[fields.InputDataFields.groundtruth_boxes] is a
          [batch_size, num_boxes, 4] float32 tensor containing the corners of
          the groundtruth boxes.
        labels[fields.InputDataFields.groundtruth_classes] is a
          [batch_size, num_boxes, num_classes] float32 one-hot tensor of
          classes.
        labels[fields.InputDataFields.groundtruth_weights] is a
          [batch_size, num_boxes] float32 tensor containing groundtruth weights
          for the boxes.
        -- Optional --
        labels[fields.InputDataFields.groundtruth_instance_masks] is a
          [batch_size, num_boxes, H, W] float32 tensor containing only binary
          values, which represent instance masks for objects.
        labels[fields.InputDataFields.groundtruth_keypoints] is a
          [batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
          keypoints for each box.

    Raises:
      TypeError: if the `train_config` or `train_input_config` are not of the
        correct type.
    """
        if not isinstance(train_config, train_pb2.TrainConfig):
            raise TypeError('For training mode, the `train_config` must be a '
                            'train_pb2.TrainConfig.')
        if not isinstance(train_input_config, input_reader_pb2.InputReader):
            raise TypeError('The `train_input_config` must be a '
                            'input_reader_pb2.InputReader.')
        if not isinstance(model_config, model_pb2.DetectionModel):
            raise TypeError('The `model_config` must be a '
                            'model_pb2.DetectionModel.')

        data_augmentation_options = [
            preprocessor_builder.build(step)
            for step in train_config.data_augmentation_options
        ]
        data_augmentation_fn = functools.partial(
            augment_input_data,
            data_augmentation_options=data_augmentation_options)

        model = model_builder.build(model_config, is_training=True)
        image_resizer_config = config_util.get_image_resizer_config(
            model_config)
        image_resizer_fn = image_resizer_builder.build(image_resizer_config)

        transform_data_fn = functools.partial(
            transform_input_data,
            model_preprocess_fn=model.preprocess,
            image_resizer_fn=image_resizer_fn,
            num_classes=config_util.get_number_of_classes(model_config),
            data_augmentation_fn=data_augmentation_fn)
        dataset = dataset_builder.build(
            train_input_config,
            transform_input_data_fn=transform_data_fn,
            batch_size=params['batch_size']
            if params else train_config.batch_size,
            max_num_boxes=train_config.max_number_of_boxes,
            num_classes=config_util.get_number_of_classes(model_config),
            spatial_image_shape=config_util.get_spatial_image_size(
                image_resizer_config))
        tensor_dict = dataset_util.make_initializable_iterator(
            dataset).get_next()

        hash_from_source_id = tf.string_to_hash_bucket_fast(
            tensor_dict[fields.InputDataFields.source_id], HASH_BINS)
        features = {
            fields.InputDataFields.image:
            tensor_dict[fields.InputDataFields.image],
            HASH_KEY:
            tf.cast(hash_from_source_id, tf.int32),
            fields.InputDataFields.true_image_shape:
            tensor_dict[fields.InputDataFields.true_image_shape]
        }

        labels = {
            fields.InputDataFields.num_groundtruth_boxes:
            tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
            fields.InputDataFields.groundtruth_boxes:
            tensor_dict[fields.InputDataFields.groundtruth_boxes],
            fields.InputDataFields.groundtruth_classes:
            tensor_dict[fields.InputDataFields.groundtruth_classes],
            fields.InputDataFields.groundtruth_weights:
            tensor_dict[fields.InputDataFields.groundtruth_weights]
        }
        if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
            labels[fields.InputDataFields.groundtruth_keypoints] = tensor_dict[
                fields.InputDataFields.groundtruth_keypoints]
        if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
            labels[fields.InputDataFields.
                   groundtruth_instance_masks] = tensor_dict[
                       fields.InputDataFields.groundtruth_instance_masks]

        return features, labels
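A hedged sketch of how an input_fn with this (features, labels) contract is typically consumed: it is handed to a tf.estimator.Estimator, which calls it to build the input pipeline for training. `my_model_fn` and the model_dir path are placeholders, not part of the original code.

estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir='/tmp/od_model')
estimator.train(input_fn=_train_input_fn, max_steps=train_config.num_steps)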
Example no. 33
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      graph_hook_fn: Optional function that is called after the inference graph is
        built (before optimization). This is helpful to perform additional changes
        to the training graph such as adding FakeQuant ops. The function should
        modify the default graph.

    Raises:
      ValueError: If both num_clones > 1 and train_config.sync_replicas is true.
    """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            raise ValueError('In Synchronous SGD mode num_clones must be 1. '
                             'Found num_clones: {}'.format(num_clones))
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity, data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            regularization_losses = (None if train_config.add_regularization_loss
                                     else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                              global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                      model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                                   loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, fine_tune_checkpoint_type is set based on
                # from_detection_checkpoint.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (variables_helper.
                get_variables_available_in_checkpoint(
                var_map, train_config.fine_tune_checkpoint,
                include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)
                # optimistic_restore(init_saver, sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(
                train_config.num_steps if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
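A hedged sketch of how a train() function with this signature is typically called from a training binary (compare Example no. 35 below); the single-machine argument values are illustrative.

model_fn = functools.partial(model_builder.build,
                             model_config=model_config, is_training=True)
input_fn = functools.partial(get_next, input_config)
train(input_fn, model_fn, train_config, master='', task=0, num_clones=1,
      worker_replicas=1, clone_on_cpu=False, ps_tasks=0,
      worker_job_name='lonely_worker', is_chief=True, train_dir=FLAGS.train_dir)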
Example no. 34
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

    detection_model = create_model_fn()  # Create the detection model object.
    data_augmentation_options = [  # For SSD this is typically ssd_random_crop;
        preprocessor_builder.build(step)  # random_horizontal_flip in the Faster R-CNN config file.
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():  # A default graph is needed in order to create the model.
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables;
        # the global step keeps track of how many training steps have run.
        with tf.device(deploy_config.variables_device()):  # Typically the CPU.
            global_step = slim.create_global_step()  # Create the global step tensor.


        # The following creates an input queue of images, boxes and targets.
        with tf.device(deploy_config.inputs_device()):  # Device used to build the inputs (usually the CPU).
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,  # Per-clone batch size.
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)  # e.g. random_horizontal_flip

        # Gather initial summaries.
        summaries = set(tf.get_collection(
            tf.GraphKeys.SUMMARIES))  # Collect the existing summaries.
        global_summaries = set([])
        # Create the losses.
        model_fn = functools.partial(
            _create_losses,  # _create_losses builds the losses; it needs the model constructor as an argument.
            create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(
            deploy_config, model_fn,
            [input_queue])  # Create the clones from model_fn and the input queue.
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer,
                global_summaries)  # Selects e.g. RMSProp or Adam and returns the optimizer directly.

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(  # Wrap the optimizer to synchronize gradient aggregation across replicas.
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:  # Path to the fine-tune checkpoint.
            init_fn = detection_model.restore_fn(  # Restore the feature-extractor weights,
                train_config.fine_tune_checkpoint,
                from_detection_checkpoint=train_config.from_detection_checkpoint
            )  # returning an initializer fn that restores from the checkpoint.

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(  # Returns the total loss and the (gradient, variable) pairs.
                clones,
                training_optimizer,
                regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:  # Skipped when no bias gradient multiplier is configured.
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:  # No variables are frozen here, but freezing can be useful.
                # Useful for inspecting the (gradient, variable) tuples.
                print("Printing grads_and_vars to check the tuples")
                print(grads_and_vars)
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars,
                    train_config.freeze_variables)  # Returns the gradients and variables except the frozen ones.
            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars,  #updating the gradinets list 
                global_step=global_step)
            update_ops.append(grad_updates)  # Add the gradient update op to the other update ops (e.g. batch norm).

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(  # Saver used for periodic checkpoints.
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(  # Run the training loop with slim's high-level helper.
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
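
The final train op above is built with a common TF 1.x graph-mode pattern: group the gradient update with the collected update_ops (e.g. batch-norm moving averages), then return the loss through tf.identity under a control dependency so that evaluating the loss forces all updates to run. A minimal, self-contained sketch of that pattern follows; the toy variable and loss are illustrative, not taken from the snippet above.

# Minimal sketch of the update_ops / control_dependencies / tf.identity pattern
# (assumes TensorFlow 1.x graph mode; the toy loss is a stand-in for total_loss).
import tensorflow as tf

w = tf.Variable(2.0, name='w')
loss = tf.square(w - 5.0)
optimizer = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss)
grad_updates = optimizer.apply_gradients(grads_and_vars)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # e.g. batch-norm updates
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)

# Running train_tensor forces every op grouped in update_op to run and returns
# the loss value, mirroring the 'train_op' identity above.
with tf.control_dependencies([update_op]):
    train_tensor = tf.identity(loss, name='train_op')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        print(sess.run(train_tensor))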
Example n. 35
0
    model = functools.partial(model_builder.build,
                              model_config=model_config,
                              is_training=True)
    input = functools.partial(get_next, input_config)
    #trainer.train(input, model, train_config, 'lonely_worker', FLAGS.train_dir)

    master = ''
    task = 0
    worker_replicas = 1
    ps_tasks = 0
    num_clones = 1
    clone_on_cpu = False
    is_chief = True
    detection_model = model()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name='lonely_worker')

        #print(deploy_config.variables_device(), 'FTTTTTTG')
        with tf.device(deploy_config.variables_device()):
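
The snippet above defers both model construction and input fetching by wrapping the builders with functools.partial, so the trainer can invoke them later with no arguments inside its own graph. A minimal sketch of that deferral pattern, with hypothetical stand-ins for model_builder.build and get_next, could look like this.

# Minimal sketch of the functools.partial deferral used above; build_model and
# build_input are hypothetical stand-ins for model_builder.build and get_next.
import functools

def build_model(model_config, is_training):
    return {'config': model_config, 'is_training': is_training}

def build_input(input_config):
    return ['example_batch', input_config]

create_model_fn = functools.partial(build_model,
                                    model_config='pipeline.config',
                                    is_training=True)
create_input_fn = functools.partial(build_input, 'input.config')

# A trainer can now call these zero-argument callables inside its own graph scope.
print(create_model_fn())
print(create_input_fn())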
Example n. 36
0
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
    """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.
                from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        # Don't allocate all available GPU memory up front.
        session_config.gpu_options.allow_growth = True

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            max_to_keep=10,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
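
Snippet 36 initializes fine-tuning by asking the model for a restore map, filtering it down to the variables that actually exist in the checkpoint, and building a Saver over the survivors. A hedged sketch of that filtering step follows; it mirrors what variables_helper.get_variables_available_in_checkpoint does but is not the library implementation, and ckpt_path is a placeholder.

# Minimal sketch (TensorFlow 1.x): keep only variables whose names are present
# in the checkpoint, then restore them in an init_fn, as in the snippet above.
import tensorflow as tf

def filter_restorable(var_map, ckpt_path):
    """Return the subset of var_map whose names exist in the checkpoint."""
    reader = tf.train.NewCheckpointReader(ckpt_path)
    ckpt_shapes = reader.get_variable_to_shape_map()
    return {name: var for name, var in var_map.items() if name in ckpt_shapes}

# Usage, following the pattern above (ckpt_path is hypothetical):
#   var_map = detection_model.restore_map(from_detection_checkpoint=True)
#   available_var_map = filter_restorable(var_map, ckpt_path)
#   init_saver = tf.train.Saver(available_var_map)
#   init_fn = lambda sess: init_saver.restore(sess, ckpt_path)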
Example n. 37
0
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          train_config.batch_size // num_clones, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
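
Snippet 37 replaces the deprecated from_detection_checkpoint boolean with the fine_tune_checkpoint_type string, falling back to the old flag only when the new field is unset. That mapping is easy to isolate; the sketch below mirrors the branch above with plain arguments standing in for the train_pb2.TrainConfig fields.

# Minimal sketch of the backward-compatibility mapping in snippet 37.
def resolve_checkpoint_type(fine_tune_checkpoint_type, from_detection_checkpoint):
    """Return the effective checkpoint type, honouring the deprecated flag."""
    if fine_tune_checkpoint_type:
        return fine_tune_checkpoint_type
    return 'detection' if from_detection_checkpoint else 'classification'

print(resolve_checkpoint_type('', True))                 # -> 'detection'
print(resolve_checkpoint_type('', False))                # -> 'classification'
print(resolve_checkpoint_type('classification', True))   # -> 'classification'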
Example n. 38
0
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          num_examples,
          total_configs,
          model_config,
          is_first_training=True):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    num_examples: The number of examples in the dataset used for training.
    total_configs: List of config protobufs used to build the config summaries.
    model_config: The model configuration protobuf (used here for its MTL
      options).
    is_first_training: Whether this is the first training run; if False, the
      global step is initialized from the step number in the fine-tune
      checkpoint name.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            if is_first_training:
                global_step = slim.create_global_step()
            else:
                prev_global_step = int(
                    train_config.fine_tune_checkpoint.split('-')[-1])
                global_step = variable_scope.get_variable(
                    ops.GraphKeys.GLOBAL_STEP,
                    dtype=dtypes.int64,
                    initializer=tf.constant(prev_global_step,
                                            dtype=dtypes.int64),
                    trainable=False,
                    collections=[
                        ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.GLOBAL_STEP
                    ])

        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options,
                ignore_options=train_config.ignore_options,
                mtl_window=model_config.mtl.window,
                mtl_edgemask=model_config.mtl.edgemask)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        kwargs = {}
        kwargs['mtl'] = model_config.mtl

        update_schedule = None
        model_fn = functools.partial(
            _create_losses,
            create_model_fn=create_model_fn,
            show_image_summary=train_config.show_image_summary,
            update_schedule=update_schedule,
            **kwargs)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope
        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            # TODO: support synchronous update for manual loss update
            training_optimizer = tf.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.
                from_detection_checkpoint,
                restore_box_predictor=train_config.restore_box_predictor,
                restore_window=train_config.restore_window,
                restore_edgemask=train_config.restore_edgemask,
                restore_closeness=train_config.restore_closeness,
                restore_mtl_refine=train_config.restore_mtl_refine,
            )
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            mtl = model_config.mtl
            mtl_init_saver_list = []

            def _get_mtl_init_saver(scope_name):
                _var_map = detection_model._feature_extractor.mtl_restore_from_classification_checkpoint_fn(
                    scope_name)
                if train_config.from_detection_checkpoint:
                    _var_map_new = dict()
                    for name, val in _var_map.iteritems():
                        _var_map_new[detection_model.
                                     second_stage_feature_extractor_scope +
                                     '/' + name] = val
                    _var_map = _var_map_new
                _available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        _var_map, train_config.fine_tune_checkpoint))
                if _available_var_map:
                    return tf.train.Saver(_available_var_map)
                else:
                    return None

            # if mtl.share_second_stage_init and mtl.shared_feature == 'proposal_feature_maps':
            if mtl.share_second_stage_init and train_config.from_detection_checkpoint == False:
                if mtl.window:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.window_box_predictor_scope))
                if mtl.closeness:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.closeness_box_predictor_scope))
                if mtl.edgemask:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.edgemask_predictor_scope))

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)
                for mtl_init_saver in mtl_init_saver_list:
                    if mtl_init_saver is not None:
                        mtl_init_saver.restore(
                            sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        def _get_trainable_variables(except_scopes=None):
            trainable_variables = tf.trainable_variables()
            if except_scopes is None:
                return trainable_variables
            for var in tf.trainable_variables():
                if any([scope in var.name for scope in except_scopes]):
                    trainable_variables.remove(var)
            return trainable_variables

        def _get_update_ops(except_scopes=None):
            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
            if except_scopes is None:
                return update_ops
            for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                         first_clone_scope):
                if any([scope in var.name for scope in except_scopes]):
                    update_ops.remove(var)
            return update_ops

        with tf.device(deploy_config.optimizer_device()):

            def _single_update():
                kwargs = {}
                _training_optimizer = training_optimizer
                kwargs['var_list'] = None
                update_ops = _get_update_ops()
                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones,
                    _training_optimizer,
                    regularization_losses=None,
                    **kwargs)

                # Optionally multiply gradients by train_config.{grad_multiplier,
                # divide_grad_by_batch}.
                if train_config.grad_multiplier or train_config.divide_grad_by_batch:
                    base_multiplier = train_config.grad_multiplier \
                        if train_config.grad_multiplier else 1.0
                    batch_divider = float(train_config.batch_size) \
                        if train_config.divide_grad_by_batch else 1.0
                    total_multiplier = base_multiplier / batch_divider
                    grads_and_vars = variables_helper.multiply_gradients_by_scalar_multiplier(
                        grads_and_vars, multiplier=total_multiplier)

                # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
                if train_config.bias_grad_multiplier:
                    biases_regex_list = ['.*/biases']
                    grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                        grads_and_vars,
                        biases_regex_list,
                        multiplier=train_config.bias_grad_multiplier)

                # Optionally freeze some layers by setting their gradients to be zero.
                if train_config.freeze_variables:
                    grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                        grads_and_vars, train_config.freeze_variables)

                # Optionally clip gradients
                if train_config.gradient_clipping_by_norm > 0:
                    with tf.name_scope('clip_grads'):
                        grads_and_vars = slim.learning.clip_gradient_norms(
                            grads_and_vars,
                            train_config.gradient_clipping_by_norm)

                # Create gradient updates.
                grad_updates = _training_optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                # update_ops.append(grad_updates)
                total_update_ops = update_ops + [grad_updates]

                update_op = tf.group(*total_update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name=('train_op'))
                return train_tensor

            train_tensor = _single_update()

        # Add summaries.
        def _get_total_loss_with_collection(collection,
                                            add_regularization_losses=True,
                                            name="total_loss"):
            losses = tf.losses.get_losses(loss_collection=collection)
            if add_regularization_losses:
                losses += tf.losses.get_regularization_losses()
            return math_ops.add_n(losses, name=name)

        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # not contained in global_summaries
        config_summary_list = select_config_summary_list(total_configs,
                                                         as_matrix=False)

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        custom_learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            global_step=(None if is_first_training else global_step),
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            log_every_n_steps=(train_config.log_every_n_steps
                               if train_config.log_every_n_steps else None),
            save_summaries_secs=train_config.save_summaries_secs,
            save_interval_secs=train_config.save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver,
            batch_size=train_config.batch_size,
            num_examples=num_examples,
            config_summary_list=config_summary_list)
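
Inside _single_update above, grad_multiplier and divide_grad_by_batch collapse into a single scalar that is applied to every gradient before the optional bias, freeze, and clipping steps. The sketch below reproduces that arithmetic with plain Python floats standing in for gradient tensors; the scaling helper mirrors, but is not, variables_helper.multiply_gradients_by_scalar_multiplier.

# Minimal sketch of the gradient-scaling logic in _single_update.
def total_multiplier(grad_multiplier, divide_grad_by_batch, batch_size):
    base_multiplier = grad_multiplier if grad_multiplier else 1.0
    batch_divider = float(batch_size) if divide_grad_by_batch else 1.0
    return base_multiplier / batch_divider

def scale_gradients(grads_and_vars, multiplier):
    return [(grad * multiplier, var) for grad, var in grads_and_vars]

grads_and_vars = [(0.5, 'conv1/weights'), (1.0, 'conv1/biases')]
multiplier = total_multiplier(grad_multiplier=2.0, divide_grad_by_batch=True,
                              batch_size=8)
print(multiplier)                                   # 0.25
print(scale_gradients(grads_and_vars, multiplier))  # [(0.125, ...), (0.25, ...)]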