Beispiel #1
0
  def __init__(self, min_level, max_level, num_scales, aspect_ratios,
               anchor_scale, image_size):
    """Builds the multiscale Mask-RCNN anchors.

    The constructor logs the aspect ratios for MLPerf, stores every argument
    as an attribute of the same name, and then populates `self.config` and
    `self.boxes` from `_generate_configs()` and `_generate_boxes()`.

    Args:
      min_level: integer, lowest level of the output feature pyramid.
      max_level: integer, highest level of the output feature pyramid.
      num_scales: integer, number of intermediate scales added on each level.
        For instance, num_scales=2 adds two additional anchor scales
        [2^0, 2^0.5] on each level.
      aspect_ratios: list of tuples, the aspect ratio anchors added on each
        level. For instance, aspect_ratios =
        [(1, 1), (1.4, 0.7), (0.7, 1.4)] adds three anchors on each level.
      anchor_scale: float, scale of the base anchor size relative to the
        feature stride 2^level.
      image_size: integer, input image size. The input image has the same
        dimension for width and height, and image_size should be divisible by
        the largest feature stride 2^max_level.
    """
    mlperf_log.maskrcnn_print(
        key=mlperf_log.ASPECT_RATIOS, value=aspect_ratios)

    # Feature pyramid range.
    self.min_level, self.max_level = min_level, max_level
    # Anchor shape parameters shared by every level.
    self.num_scales, self.aspect_ratios = num_scales, aspect_ratios
    self.anchor_scale, self.image_size = anchor_scale, image_size

    # Derived state is built last, once all attributes above are stored
    # (presumably the two generators read them -- they are not visible here).
    self.config = self._generate_configs()
    self.boxes = self._generate_boxes()
def evaluation(eval_estimator, num_epochs, val_json_file):
  """Runs one evaluation pass and logs the resulting metrics.

  Pulls `FLAGS.eval_samples // FLAGS.eval_batch_size` batches from the
  estimator's predictor, rewrites the detection boxes in place, gathers all
  batches per output key and hands them to `coco_metric.EvaluationMetric`.

  Args:
    eval_estimator: estimator whose `predict` yields batched prediction
      dictionaries containing at least 'detections' and 'image_info'.
    num_epochs: epoch value attached to the MLPerf EVAL_START / EVAL_STOP /
      EVAL_ACCURACY log entries.
    val_json_file: argument forwarded to `coco_metric.EvaluationMetric`.

  Returns:
    The value returned by `predict_metric_fn`; it is indexed with 'AP' and
    'mask_AP' below.
  """
  mlperf_log.maskrcnn_print(key=mlperf_log.EVAL_START,
                            value=num_epochs)
  mlperf_log.maskrcnn_print(key=mlperf_log.BATCH_SIZE_TEST,
                            value=FLAGS.eval_batch_size)
  predictor = eval_estimator.predict(
      input_fn=dataloader.InputReader(
          FLAGS.validation_file_pattern,
          mode=tf.estimator.ModeKeys.PREDICT),
      yield_single_examples=False)

  # Every six.next(predictor) returns one batch of predictions (a dictionary).
  # Batches are collected per key and concatenated once at the end; appending
  # with np.append per batch would re-copy everything accumulated so far.
  batches = {}
  for _ in range(FLAGS.eval_samples // FLAGS.eval_batch_size):
    prediction = six.next(predictor)
    image_info = prediction['image_info']
    detections = prediction['detections']
    for b in range(detections.shape[0]):
      scale = image_info[b][2]
      # Map [y1, x1, y2, x2] -> [x1, y1, w, h] and multiply detections by the
      # image scale, for all boxes of this image at once. The right-hand side
      # is materialized into a new array before the in-place assignment, so
      # reading the columns through views is safe.
      y1, x1, y2, x2 = (detections[b, :, col] for col in range(1, 5))
      detections[b, :, 1:5] = scale * np.stack(
          [x1, y1, x2 - x1, y2 - y1], axis=-1)

    for k, v in six.iteritems(prediction):
      batches.setdefault(k, []).append(v)

  # A single batch is passed through untouched, as before.
  predictions = {
      k: v[0] if len(v) == 1 else np.concatenate(v, axis=0)
      for k, v in six.iteritems(batches)
  }

  eval_metric = coco_metric.EvaluationMetric(val_json_file)
  eval_results = eval_metric.predict_metric_fn(predictions)
  tf.logging.info('Eval results: %s', eval_results)
  mlperf_log.maskrcnn_print(key=mlperf_log.EVAL_STOP,
                            value=num_epochs)
  mlperf_log.maskrcnn_print(key=mlperf_log.EVAL_SIZE,
                            value=FLAGS.eval_samples)
  mlperf_log.maskrcnn_print(
      key=mlperf_log.EVAL_ACCURACY,
      value={
          'epoch': num_epochs,
          'box_AP': str(eval_results['AP']),
          'mask_AP': str(eval_results['mask_AP']),
      })

  return eval_results
def write_summary(eval_results, summary_writer, current_step):
    """Writes the eval metrics of one checkpoint to the summary writer.

    Every entry of `eval_results` becomes a scalar summary value tagged with
    its key. The MLPerf evaluation targets are logged afterwards.

    Args:
      eval_results: mapping from metric name to scalar value.
      summary_writer: writer receiving the summary via `add_summary`.
      current_step: step passed to `add_summary` alongside the summary.
    """
    with tf.Graph().as_default():
        metric_values = [
            tf.Summary.Value(tag=name, simple_value=eval_results[name])
            for name in eval_results
        ]
        summary_writer.add_summary(
            tf.Summary(value=metric_values), current_step)

        target_aps = {
            'box_AP': BOX_EVAL_TARGET,
            'mask_AP': MASK_EVAL_TARGET
        }
        mlperf_log.maskrcnn_print(key=mlperf_log.EVAL_TARGET,
                                  value=target_aps)
    def normalize_image(self):
        """Converts `self._image` to float32 and standardizes each channel.

        After the dtype conversion, the per-channel offset
        [0.485, 0.456, 0.406] is subtracted and the result is divided by the
        per-channel scale [0.229, 0.224, 0.225]. The scale values are also
        written to the MLPerf log. `self._image` is updated in place.
        """
        channel_offset = [0.485, 0.456, 0.406]
        channel_scale = [0.229, 0.224, 0.225]

        # The image normalization is identical to Cloud TPU ResNet.
        self._image = tf.image.convert_image_dtype(
            self._image, dtype=tf.float32)

        # Two leading unit axes are prepended, giving a [1, 1, 3] constant.
        offset = tf.expand_dims(
            tf.expand_dims(tf.constant(channel_offset), axis=0), axis=0)
        self._image = self._image - offset

        # The original note compares this to `PIXEL_MEANS` in the reference. Reference: https://github.com/ddkang/Detectron/blob/80f329530843e66d07ca39e19901d5f3e5daf009/lib/core/config.py#L909  # pylint: disable=line-too-long
        mlperf_log.maskrcnn_print(key=mlperf_log.INPUT_NORMALIZATION_STD,
                                  value=[0.229, 0.224, 0.225])
        scale = tf.expand_dims(
            tf.expand_dims(tf.constant(channel_scale), axis=0), axis=0)
        self._image = self._image / scale
 def __init__(self, image, output_size, short_side_image_size,
              long_side_max_image_size, boxes=None, classes=None,
              masks=None):
     """Sets up the processor and attaches the optional annotations.

     The image and size arguments are forwarded unchanged to
     `InputProcessor.__init__`; the short-side and long-side limits are then
     written to the MLPerf log as MIN_IMAGE_SIZE and MAX_IMAGE_SIZE.

     Args:
       image: forwarded to `InputProcessor.__init__`.
       output_size: forwarded to `InputProcessor.__init__`.
       short_side_image_size: forwarded to `InputProcessor.__init__` and
         logged as the minimum image size.
       long_side_max_image_size: forwarded to `InputProcessor.__init__` and
         logged as the maximum image size.
       boxes: optional value stored as `self._boxes`.
       classes: optional value stored as `self._classes`.
       masks: optional value stored as `self._masks`.
     """
     InputProcessor.__init__(
         self, image, output_size, short_side_image_size,
         long_side_max_image_size)

     mlperf_log.maskrcnn_print(
         key=mlperf_log.MIN_IMAGE_SIZE, value=short_side_image_size)
     mlperf_log.maskrcnn_print(
         key=mlperf_log.MAX_IMAGE_SIZE, value=long_side_max_image_size)

     self._boxes, self._classes, self._masks = boxes, classes, masks
def main(argv):
  """Trains and/or evaluates Mask-RCNN according to `--mode`.

  Handled modes are 'train', 'eval' and 'train_and_eval'; any other value is
  only logged. All configuration is read from the module-level `FLAGS`.

  Args:
    argv: command-line arguments; unused.

  Raises:
    RuntimeError: if a file pattern or JSON flag required by the selected mode
      is missing, or if spatial partitioning is enabled and
      `--num_cores_per_replica` differs from the product of
      `--input_partition_dims`.
  """
  del argv  # Unused.

  if FLAGS.use_tpu:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu,
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)
    tpu_grpc_url = tpu_cluster_resolver.get_master()
    tf.Session.reset(tpu_grpc_url)
  else:
    tpu_cluster_resolver = None

  # Check data path
  if FLAGS.mode in ('train',
                    'train_and_eval') and FLAGS.training_file_pattern is None:
    raise RuntimeError('You must specify --training_file_pattern for training.')
  if FLAGS.mode in ('eval', 'train_and_eval'):
    if FLAGS.validation_file_pattern is None:
      raise RuntimeError('You must specify --validation_file_pattern '
                         'for evaluation.')
    if FLAGS.val_json_file is None:
      raise RuntimeError('You must specify --val_json_file for evaluation.')

  # Parse hparams
  hparams = mask_rcnn_model.default_hparams()
  hparams.parse(FLAGS.hparams)

  # The following is for spatial partitioning. `features` has one tensor while
  # `labels` had 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
  # partition is performed on `features` and all partitionable tensors of
  # `labels`, see the partition logic below.
  # In the TPUEstimator context, the meaning of `shard` and `replica` is the
  # same; following the API, here has mixed use of both.
  if FLAGS.use_spatial_partition:
    # Checks input_partition_dims agrees with num_cores_per_replica.
    if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
      raise RuntimeError('--num_cores_per_replica must be a product of array '
                         'elements in --input_partition_dims.')

    labels_partition_dims = {
        'mean_num_positives': None,
        'source_ids': None,
        'groundtruth_data': None,
        'image_scales': None,
    }
    # The Input Partition Logic: We partition only the partition-able tensors.
    # Spatial partition requires that the to-be-partitioned tensors must have a
    # dimension that is a multiple of `partition_dims`. Depending on the
    # `partition_dims` and the `image_size` and the `max_level` in hparams, some
    # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
    # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
    # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
    # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
    # case, the level-8 and level-9 target tensors are not partition-able, and
    # the highest partition-able level is 7.
    partition_dims = np.array(FLAGS.input_partition_dims)
    image_size = hparams.get('image_size')
    for level in range(hparams.get('min_level'), hparams.get('max_level') + 1):
      spatial_dim = image_size // (2 ** level)
      # A level can be partitioned only when its spatial size is divisible by
      # every entry of `input_partition_dims`.
      if np.all(spatial_dim % partition_dims == 0):
        level_dims = FLAGS.input_partition_dims
      else:
        level_dims = None
      labels_partition_dims['box_targets_%d' % level] = level_dims
      labels_partition_dims['cls_targets_%d' % level] = level_dims
    num_cores_per_replica = FLAGS.num_cores_per_replica
    input_partition_dims = [
        FLAGS.input_partition_dims, labels_partition_dims]
    num_shards = FLAGS.num_cores // num_cores_per_replica
  else:
    num_cores_per_replica = None
    input_partition_dims = None
    num_shards = FLAGS.num_cores

  params = dict(
      hparams.values(),
      num_shards=num_shards,
      num_examples_per_epoch=FLAGS.num_examples_per_epoch,
      use_tpu=FLAGS.use_tpu,
      resnet_checkpoint=FLAGS.resnet_checkpoint,
      val_json_file=FLAGS.val_json_file,
      mode=FLAGS.mode,
      # The following are used by the host_call function.
      model_dir=FLAGS.model_dir,
      iterations_per_loop=FLAGS.iterations_per_loop,
      dynamic_input_shapes=FLAGS.dynamic_input_shapes,
      transpose_input=FLAGS.transpose_input)

  tpu_config = tf.contrib.tpu.TPUConfig(
      FLAGS.iterations_per_loop,
      num_shards=num_shards,
      num_cores_per_replica=num_cores_per_replica,
      input_partition_dims=input_partition_dims,
      per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  )

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      log_step_count_steps=FLAGS.iterations_per_loop,
      tpu_config=tpu_config,
  )

  # ---- Local helpers shared by the three modes below. ----

  def _total_train_steps():
    """Returns the number of training steps implied by --num_epochs."""
    return int((FLAGS.num_epochs * float(FLAGS.num_examples_per_epoch)) /
               float(FLAGS.train_batch_size))

  def _train_input_fn():
    """Returns a fresh training input reader."""
    return dataloader.InputReader(
        FLAGS.training_file_pattern, mode=tf.estimator.ModeKeys.TRAIN)

  def _make_train_estimator():
    """Builds the estimator used for training."""
    return tf.contrib.tpu.TPUEstimator(
        model_fn=mask_rcnn_model.mask_rcnn_model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        config=run_config,
        params=params)

  def _make_eval_params(**overrides):
    """Returns `params` with the eval-time settings plus mode overrides."""
    return dict(
        params,
        use_tpu=FLAGS.use_tpu,
        input_rand_hflip=False,
        resnet_checkpoint=None,
        is_training_bn=False,
        **overrides)

  def _make_eval_estimator(eval_params):
    """Builds the estimator used for prediction-based evaluation."""
    return tf.contrib.tpu.TPUEstimator(
        model_fn=mask_rcnn_model.mask_rcnn_model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size,
        config=run_config,
        params=eval_params)

  def _make_eval_summary_writer():
    """Creates `<model_dir>/eval` and returns a summary writer for it."""
    output_dir = os.path.join(FLAGS.model_dir, 'eval')
    tf.gfile.MakeDirs(output_dir)
    # Summary writer writes out eval metrics.
    return tf.summary.FileWriter(output_dir)

  def _evaluate(eval_estimator, epoch, summary_writer, step):
    """Evaluates, writes summaries, and returns whether both targets are met."""
    eval_results = evaluation(eval_estimator, epoch, params['val_json_file'])
    write_summary(eval_results, summary_writer, step)
    return (eval_results['AP'] >= BOX_EVAL_TARGET and
            eval_results['mask_AP'] >= MASK_EVAL_TARGET)

  def _log_run_stop(success):
    """Logs the MLPerf RUN_STOP entry with the given outcome."""
    mlperf_log.maskrcnn_print(
        key=mlperf_log.RUN_STOP,
        value={'success': 'true' if success else 'false'})

  if FLAGS.mode != 'eval':
    mlperf_log.maskrcnn_print(key=mlperf_log.RUN_START)
    mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_LOOP)
    mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_EPOCH, value=0)

  if FLAGS.mode == 'train':

    max_steps = _total_train_steps()
    if params['dynamic_input_shapes']:
      train_with_dynamic_shapes(params, tpu_cluster_resolver, max_steps,
                                FLAGS.iterations_per_loop)
    else:
      tf.logging.info(params)
      train_estimator = _make_train_estimator()
      train_estimator.train(input_fn=_train_input_fn(), max_steps=max_steps)

    if FLAGS.eval_after_training:
      # Run evaluation after training finishes.
      eval_estimator = _make_eval_estimator(
          _make_eval_params(dynamic_input_shapes=False,
                            transpose_input=False))
      summary_writer = _make_eval_summary_writer()
      _log_run_stop(
          _evaluate(eval_estimator, FLAGS.num_epochs, summary_writer,
                    max_steps))
      summary_writer.close()
    mlperf_log.maskrcnn_print(key=mlperf_log.RUN_FINAL)

  elif FLAGS.mode == 'eval':

    summary_writer = _make_eval_summary_writer()
    # NOTE(review): unlike the other modes, `dynamic_input_shapes` is not
    # forced to False here; kept as in the original -- confirm it is intended.
    eval_estimator = _make_eval_estimator(
        _make_eval_params(transpose_input=False))

    def terminate_eval():
      tf.logging.info('Terminating eval after %d seconds of no checkpoints',
                      FLAGS.eval_timeout)
      return True

    total_step = _total_train_steps()

    # Run evaluation when there's a new checkpoint
    for ckpt in tf.contrib.training.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.min_eval_interval,
        timeout=FLAGS.eval_timeout,
        timeout_fn=terminate_eval):
      # The step is parsed from the checkpoint file name ('<name>-<step>').
      current_step = int(os.path.basename(ckpt).split('-')[1])

      tf.logging.info('Starting to evaluate.')
      try:
        # epochs = steps * (examples per step) / (examples per epoch).
        current_epoch = (float(current_step) * FLAGS.train_batch_size /
                         FLAGS.num_examples_per_epoch)
        if _evaluate(eval_estimator, current_epoch, summary_writer,
                     current_step):
          _log_run_stop(True)
          break

        # Terminate eval job when final checkpoint is reached
        if current_step >= total_step:
          tf.logging.info('Evaluation finished after training step %d',
                          current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info('Checkpoint %s no longer exists, skipping checkpoint',
                        ckpt)

    summary_writer.close()

  elif FLAGS.mode == 'train_and_eval':

    summary_writer = _make_eval_summary_writer()
    train_estimator = _make_train_estimator()
    eval_estimator = _make_eval_estimator(
        _make_eval_params(dynamic_input_shapes=False))
    run_success = False
    steps_per_epoch = int(FLAGS.num_examples_per_epoch /
                          FLAGS.train_batch_size)
    for cycle in range(int(math.floor(FLAGS.num_epochs))):
      tf.logging.info('Starting training cycle, epoch: %d.', cycle)
      mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_EPOCH, value=cycle)
      current_step = (cycle + 1) * steps_per_epoch
      if params['dynamic_input_shapes']:
        tf.logging.info('Use dynamic input shapes training for %d steps. Train '
                        'to %d steps', steps_per_epoch, current_step)
        train_with_dynamic_shapes(
            params, tpu_cluster_resolver, current_step,
            FLAGS.iterations_per_loop)
      else:
        train_estimator.train(input_fn=_train_input_fn(),
                              steps=steps_per_epoch)

      tf.logging.info('Starting evaluation cycle, epoch: %d.', cycle)
      # Run evaluation after every epoch.
      if _evaluate(eval_estimator, cycle, summary_writer, current_step):
        _log_run_stop(True)
        run_success = True
        break

    if not run_success:
      current_epoch = int(math.floor(FLAGS.num_epochs))
      max_steps = _total_train_steps()
      # Final epoch.
      tf.logging.info('Starting training cycle, epoch: %d.', current_epoch)
      mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_EPOCH,
                                value=current_epoch)
      if params['dynamic_input_shapes']:
        remaining_steps = max_steps - int(current_epoch * steps_per_epoch)
        tf.logging.info('Use dynamic input shapes training for %d steps. Train '
                        'to %d steps', remaining_steps, max_steps)
        train_with_dynamic_shapes(
            params, tpu_cluster_resolver, max_steps, remaining_steps)
      else:
        train_estimator.train(input_fn=_train_input_fn(),
                              max_steps=max_steps)
      _log_run_stop(
          _evaluate(eval_estimator, current_epoch, summary_writer, max_steps))
    mlperf_log.maskrcnn_print(key=mlperf_log.RUN_FINAL)
    summary_writer.close()
  else:
    tf.logging.info('Mode not found.')
    def _model_outputs():
        """Runs the model heads and gathers their outputs in a dictionary.

        Reads `features`, `labels`, `all_anchors`, `mode`, `params` and
        `model` from the enclosing scope.

        Returns:
          A dictionary with the RPN, class, box and mask outputs plus the
          class/box targets and ROIs. In TRAIN mode it additionally carries
          the selected targets/ROIs and `mask_targets`; in every other mode it
          carries `detections` instead.
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        fpn_feats, rpn_fn, faster_rcnn_fn, mask_rcnn_fn = model(
            features, labels, all_anchors, mode, params)
        rpn_score_outputs, rpn_box_outputs = rpn_fn(fpn_feats)
        (class_outputs, box_outputs, class_targets, box_targets, box_rois,
         proposal_to_label_map) = faster_rcnn_fn(
             fpn_feats, rpn_score_outputs, rpn_box_outputs)
        encoded_box_targets = mask_rcnn_architecture.encode_box_targets(
            box_rois, box_targets, class_targets, params['bbox_reg_weights'])

        # The mask branch takes inputs from different places in training vs in
        # eval/predict. In training, the mask branch uses proposals combined
        # with labels to produce both mask outputs and targets. At test time,
        # it uses the post-processed predictions to generate masks.
        if is_training:
            (mask_outputs, select_class_targets, select_box_targets,
             select_box_rois, select_proposal_to_label_map,
             mask_targets) = mask_rcnn_fn(fpn_feats, class_targets,
                                          box_targets, box_rois,
                                          proposal_to_label_map)
            mode_outputs = {
                'select_class_targets': select_class_targets,
                'select_box_targets': select_box_targets,
                'select_box_rois': select_box_rois,
                'select_proposal_to_label_map': select_proposal_to_label_map,
                'mask_targets': mask_targets,
            }
        else:
            # Use TEST.NMS in the reference for this value. Reference: https://github.com/ddkang/Detectron/blob/80f329530843e66d07ca39e19901d5f3e5daf009/lib/core/config.py#L227  # pylint: disable=line-too-long
            mlperf_log.maskrcnn_print(key=mlperf_log.NMS_THRESHOLD,
                                      value=params['test_nms'])

            # Generate detections one image at a time.
            batch_size, _, _ = class_outputs.get_shape().as_list()
            softmax_class_outputs = tf.nn.softmax(class_outputs)
            per_image_detections = [
                anchors.generate_detections_per_image_op(
                    softmax_class_outputs[i], box_outputs[i], box_rois[i],
                    labels['source_ids'][i], labels['image_info'][i],
                    params['test_detections_per_image'],
                    params['test_rpn_post_nms_topn'], params['test_nms'],
                    params['bbox_reg_weights'])
                for i in range(batch_size)
            ]
            detections = tf.stack(per_image_detections, axis=0)
            mask_outputs = mask_rcnn_fn(fpn_feats, detections=detections)

            # Performs post-processing for eval/predict: after the transpose,
            # gather one mask per detection using index 6 of each detection
            # row as the class index.
            mask_shape = mask_outputs.get_shape().as_list()
            batch_size, num_instances, _, _, _ = mask_shape
            mask_outputs = tf.transpose(mask_outputs, [0, 1, 4, 2, 3])
            # Compute indices for batch, num_detections, and class.
            batch_indices = tf.tile(
                tf.reshape(tf.range(batch_size), [batch_size, 1]),
                [1, num_instances])
            instance_indices = tf.tile(
                tf.reshape(tf.range(num_instances), [1, num_instances]),
                [batch_size, 1])
            class_indices = tf.to_int32(detections[:, :, 6])
            gather_indices = tf.stack(
                [batch_indices, instance_indices, class_indices], axis=2)
            mask_outputs = tf.gather_nd(mask_outputs, gather_indices)
            mode_outputs = {'detections': detections}

        model_outputs = {
            'rpn_score_outputs': rpn_score_outputs,
            'rpn_box_outputs': rpn_box_outputs,
            'class_outputs': class_outputs,
            'box_outputs': box_outputs,
            'class_targets': class_targets,
            'box_targets': encoded_box_targets,
            'box_rois': box_rois,
            'mask_outputs': mask_outputs,
        }
        model_outputs.update(mode_outputs)
        return model_outputs
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model defination for the Mask-RCNN model based on ResNet.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include score targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the Mask-RCNN model outputs class logits and box regression outputs.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction.
  """
    if mode == tf.estimator.ModeKeys.PREDICT:
        labels = features
        features = labels.pop('images')

    if params['transpose_input'] and mode == tf.estimator.ModeKeys.TRAIN:
        features = tf.transpose(features, [2, 0, 1, 3])

    image_size = params['dynamic_image_size'] if params[
        'dynamic_input_shapes'] else (params['image_size'],
                                      params['image_size'])
    all_anchors = anchors.Anchors(params['min_level'], params['max_level'],
                                  params['num_scales'],
                                  params['aspect_ratios'],
                                  params['anchor_scale'], image_size)

    def _model_outputs():
        """Generates outputs from the model."""
        fpn_feats, rpn_fn, faster_rcnn_fn, mask_rcnn_fn = model(
            features, labels, all_anchors, mode, params)
        rpn_score_outputs, rpn_box_outputs = rpn_fn(fpn_feats)
        (class_outputs, box_outputs, class_targets, box_targets, box_rois,
         proposal_to_label_map) = faster_rcnn_fn(fpn_feats, rpn_score_outputs,
                                                 rpn_box_outputs)
        encoded_box_targets = mask_rcnn_architecture.encode_box_targets(
            box_rois, box_targets, class_targets, params['bbox_reg_weights'])

        if mode != tf.estimator.ModeKeys.TRAIN:
            # Use TEST.NMS in the reference for this value. Reference: https://github.com/ddkang/Detectron/blob/80f329530843e66d07ca39e19901d5f3e5daf009/lib/core/config.py#L227  # pylint: disable=line-too-long
            mlperf_log.maskrcnn_print(key=mlperf_log.NMS_THRESHOLD,
                                      value=params['test_nms'])

            # The mask branch takes inputs from different places in training vs in
            # eval/predict. In training, the mask branch uses proposals combined with
            # labels to produce both mask outputs and targets. At test time, it uses
            # the post-processed predictions to generate masks.
            # Generate detections one image at a time.
            batch_size, _, _ = class_outputs.get_shape().as_list()
            detections = []
            softmax_class_outputs = tf.nn.softmax(class_outputs)
            for i in range(batch_size):
                detections.append(
                    anchors.generate_detections_per_image_op(
                        softmax_class_outputs[i], box_outputs[i], box_rois[i],
                        labels['source_ids'][i], labels['image_info'][i],
                        params['test_detections_per_image'],
                        params['test_rpn_post_nms_topn'], params['test_nms'],
                        params['bbox_reg_weights']))
            detections = tf.stack(detections, axis=0)
            mask_outputs = mask_rcnn_fn(fpn_feats, detections=detections)
        else:
            (mask_outputs, select_class_targets, select_box_targets,
             select_box_rois, select_proposal_to_label_map,
             mask_targets) = mask_rcnn_fn(fpn_feats, class_targets,
                                          box_targets, box_rois,
                                          proposal_to_label_map)
        # Performs post-processing for eval/predict.
        if mode != tf.estimator.ModeKeys.TRAIN:
            batch_size, num_instances, _, _, _ = mask_outputs.get_shape(
            ).as_list()
            mask_outputs = tf.transpose(mask_outputs, [0, 1, 4, 2, 3])
            # Compute indices for batch, num_detections, and class.
            batch_indices = tf.tile(
                tf.reshape(tf.range(batch_size), [batch_size, 1]),
                [1, num_instances])
            instance_indices = tf.tile(
                tf.reshape(tf.range(num_instances), [1, num_instances]),
                [batch_size, 1])
            class_indices = tf.to_int32(detections[:, :, 6])
            gather_indices = tf.stack(
                [batch_indices, instance_indices, class_indices], axis=2)
            mask_outputs = tf.gather_nd(mask_outputs, gather_indices)
        model_outputs = {
            'rpn_score_outputs': rpn_score_outputs,
            'rpn_box_outputs': rpn_box_outputs,
            'class_outputs': class_outputs,
            'box_outputs': box_outputs,
            'class_targets': class_targets,
            'box_targets': encoded_box_targets,
            'box_rois': box_rois,
            'mask_outputs': mask_outputs,
        }
        if mode == tf.estimator.ModeKeys.TRAIN:
            model_outputs.update({
                'select_class_targets': select_class_targets,
                'select_box_targets': select_box_targets,
                'select_box_rois': select_box_rois,
                'select_proposal_to_label_map': select_proposal_to_label_map,
                'mask_targets': mask_targets,
            })
        else:
            model_outputs.update({'detections': detections})
        return model_outputs

    if params['use_bfloat16']:
        with tf.contrib.tpu.bfloat16_scope():
            model_outputs = _model_outputs()

            def cast_outputs_to_float(d):
                for k, v in six.iteritems(d):
                    if isinstance(v, dict):
                        cast_outputs_to_float(v)
                    else:
                        if k != 'select_proposal_to_label_map':
                            d[k] = tf.cast(v, tf.float32)

            cast_outputs_to_float(model_outputs)
    else:
        model_outputs = _model_outputs()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {}
        predictions['detections'] = model_outputs['detections']
        predictions['mask_outputs'] = tf.nn.sigmoid(
            model_outputs['mask_outputs'])
        predictions['image_info'] = labels['image_info']

        if params['use_tpu']:
            return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                   predictions=predictions)
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % params['resnet_depth'],
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params['adjusted_learning_rate'],
                                           params['lr_warmup_init'],
                                           params['lr_warmup_step'],
                                           params['first_lr_drop_step'],
                                           params['second_lr_drop_step'],
                                           global_step)
    # score_loss and box_loss are for logging. only total_loss is optimized.
    total_rpn_loss, rpn_score_loss, rpn_box_loss = rpn_loss(
        model_outputs['rpn_score_outputs'], model_outputs['rpn_box_outputs'],
        labels, params)

    (total_fast_rcnn_loss, fast_rcnn_class_loss,
     fast_rcnn_box_loss) = fast_rcnn_loss(model_outputs['class_outputs'],
                                          model_outputs['box_outputs'],
                                          model_outputs['class_targets'],
                                          model_outputs['box_targets'], params)
    # Only training has the mask loss. Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/model_builder.py  # pylint: disable=line-too-long
    if mode == tf.estimator.ModeKeys.TRAIN:
        mask_loss = mask_rcnn_loss(model_outputs['mask_outputs'],
                                   model_outputs['mask_targets'],
                                   model_outputs['select_class_targets'],
                                   params)
    else:
        mask_loss = 0.
    var_list = variable_filter_fn(
        tf.trainable_variables(),
        params['resnet_depth']) if variable_filter_fn else None
    total_loss = (
        total_rpn_loss + total_fast_rcnn_loss + mask_loss +
        _WEIGHT_DECAY * tf.add_n([
            tf.nn.l2_loss(v) for v in var_list
            if 'batch_normalization' not in v.name and 'bias' not in v.name
        ]))

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        mlperf_log.maskrcnn_print(key=mlperf_log.OPT_NAME,
                                  value='tf.train.MomentumOptimizer')
        mlperf_log.maskrcnn_print(key=mlperf_log.OPT_MOMENTUM,
                                  value=params['momentum'])
        mlperf_log.maskrcnn_print(key=mlperf_log.OPT_WEIGHT_DECAY,
                                  value=_WEIGHT_DECAY)
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=params['momentum'])
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
        gradients, variables = zip(*grads_and_vars)
        grads_and_vars = []
        # Special treatment for biases (beta is named as bias in reference model)
        # Reference: https://github.com/ddkang/Detectron/blob/80f329530843e66d07ca39e19901d5f3e5daf009/lib/modeling/optimizer.py#L109  # pylint: disable=line-too-long
        for grad, var in zip(gradients, variables):
            if 'beta' in var.name or 'bias' in var.name:
                grad = 2.0 * grad
            grads_and_vars.append((grad, var))
        minimize_op = optimizer.apply_gradients(grads_and_vars,
                                                global_step=global_step)

        with tf.control_dependencies(update_ops):
            train_op = minimize_op

        if params['use_host_call']:

            def host_call_fn(global_step, total_loss, total_rpn_loss,
                             rpn_score_loss, rpn_box_loss,
                             total_fast_rcnn_loss, fast_rcnn_class_loss,
                             fast_rcnn_box_loss, mask_loss, learning_rate):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          global_step: `Tensor with shape `[batch, ]` for the global_step.
          total_loss: `Tensor` with shape `[batch, ]` for the training loss.
          total_rpn_loss: `Tensor` with shape `[batch, ]` for the training RPN
            loss.
          rpn_score_loss: `Tensor` with shape `[batch, ]` for the training RPN
            score loss.
          rpn_box_loss: `Tensor` with shape `[batch, ]` for the training RPN
            box loss.
          total_fast_rcnn_loss: `Tensor` with shape `[batch, ]` for the
            training Mask-RCNN loss.
          fast_rcnn_class_loss: `Tensor` with shape `[batch, ]` for the
            training Mask-RCNN class loss.
          fast_rcnn_box_loss: `Tensor` with shape `[batch, ]` for the
            training Mask-RCNN box loss.
          mask_loss: `Tensor` with shape `[batch, ]` for the training Mask-RCNN
            mask loss.
          learning_rate: `Tensor` with shape `[batch, ]` for the learning_rate.

        Returns:
          List of summary ops to run on the CPU host.
        """
                # Outfeed supports int32 but global_step is expected to be int64.
                global_step = tf.reduce_mean(global_step)
                # Host call fns are executed FLAGS.iterations_per_loop times after one
                # TPU loop is finished, setting max_queue value to the same as number of
                # iterations will make the summary writer only flush the data to storage
                # once per loop.
                with (tf.contrib.summary.create_file_writer(
                        params['model_dir'],
                        max_queue=params['iterations_per_loop']).as_default()):
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar('total_loss',
                                                  tf.reduce_mean(total_loss),
                                                  step=global_step)
                        tf.contrib.summary.scalar(
                            'total_rpn_loss',
                            tf.reduce_mean(total_rpn_loss),
                            step=global_step)
                        tf.contrib.summary.scalar(
                            'rpn_score_loss',
                            tf.reduce_mean(rpn_score_loss),
                            step=global_step)
                        tf.contrib.summary.scalar('rpn_box_loss',
                                                  tf.reduce_mean(rpn_box_loss),
                                                  step=global_step)
                        tf.contrib.summary.scalar(
                            'total_fast_rcnn_loss',
                            tf.reduce_mean(total_fast_rcnn_loss),
                            step=global_step)
                        tf.contrib.summary.scalar(
                            'fast_rcnn_class_loss',
                            tf.reduce_mean(fast_rcnn_class_loss),
                            step=global_step)
                        tf.contrib.summary.scalar(
                            'fast_rcnn_box_loss',
                            tf.reduce_mean(fast_rcnn_box_loss),
                            step=global_step)
                        tf.contrib.summary.scalar('mask_loss',
                                                  tf.reduce_mean(mask_loss),
                                                  step=global_step)
                        tf.contrib.summary.scalar(
                            'learning_rate',
                            tf.reduce_mean(learning_rate),
                            step=global_step)

                        return tf.contrib.summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            global_step_t = tf.reshape(global_step, [1])
            total_loss_t = tf.reshape(total_loss, [1])
            total_rpn_loss_t = tf.reshape(total_rpn_loss, [1])
            rpn_score_loss_t = tf.reshape(rpn_score_loss, [1])
            rpn_box_loss_t = tf.reshape(rpn_box_loss, [1])
            total_fast_rcnn_loss_t = tf.reshape(total_fast_rcnn_loss, [1])
            fast_rcnn_class_loss_t = tf.reshape(fast_rcnn_class_loss, [1])
            fast_rcnn_box_loss_t = tf.reshape(fast_rcnn_box_loss, [1])
            mask_loss_t = tf.reshape(mask_loss, [1])
            learning_rate_t = tf.reshape(learning_rate, [1])
            host_call = (host_call_fn, [
                global_step_t, total_loss_t, total_rpn_loss_t,
                rpn_score_loss_t, rpn_box_loss_t, total_fast_rcnn_loss_t,
                fast_rcnn_class_loss_t, fast_rcnn_box_loss_t, mask_loss_t,
                learning_rate_t
            ])
    else:
        train_op = None

    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                           loss=total_loss,
                                           train_op=train_op,
                                           host_call=host_call,
                                           scaffold_fn=scaffold_fn)
def random_horizontal_flip(image,
                           boxes=None,
                           masks=None,
                           keypoints=None,
                           keypoint_flip_permutation=None,
                           seed=None):
  """Flips an image and its annotations left-right with 50% probability.

  A single random draw is shared by every input, so the image and whichever
  of boxes, masks and keypoints were supplied are either all flipped or all
  passed through unchanged.

  Args:
    image: rank 3 float32 tensor with shape [height, width, channels].
    boxes: (optional) rank 2 float32 tensor with shape [N, 4] holding
      normalized bounding boxes, one [ymin, xmin, ymax, xmax] row per box
      with coordinates in [0, 1].
    masks: (optional) rank 3 float32 tensor with shape
      [num_instances, height, width] holding instance masks that share the
      height and width of `image`.
    keypoints: (optional) rank 3 float32 tensor with shape
      [num_instances, num_keypoints, 2] in normalized y-x coordinates.
    keypoint_flip_permutation: rank 1 int32 tensor with the keypoint flip
      permutation. Required whenever `keypoints` is given.
    seed: random seed forwarded to the random draw.

  Returns:
    A tuple that starts with the image (same shape as the input) and is
    followed, in this order, by boxes, masks and keypoints for each of those
    arguments that was not None.

  Raises:
    ValueError: if keypoints are provided but keypoint_flip_permutation is not.
  """
  # Reject inconsistent keypoint arguments before building any graph ops.
  if keypoints is not None and keypoint_flip_permutation is None:
    raise ValueError(
        'keypoints are provided but keypoints_flip_permutation is not provided')

  with tf.name_scope('RandomHorizontalFlip', values=[image, boxes]):
    mlperf_log.maskrcnn_print(key=mlperf_log.RANDOM_FLIP_PROBABILITY,
                              value=0.5)
    # One predicate drives every tf.cond below so all inputs stay in sync.
    should_flip = tf.greater(tf.random_uniform([], seed=seed), 0.5)

    outputs = [
        tf.cond(should_flip,
                lambda: tf.image.flip_left_right(image),
                lambda: image)
    ]

    if boxes is not None:
      outputs.append(
          tf.cond(should_flip,
                  lambda: _flip_boxes_left_right(boxes),
                  lambda: boxes))

    if masks is not None:
      outputs.append(
          tf.cond(should_flip,
                  lambda: _flip_masks_left_right(masks),
                  lambda: masks))

    # The guard above guarantees a permutation exists whenever keypoints do.
    if keypoints is not None:
      outputs.append(
          tf.cond(
              should_flip,
              # 0.5 is presumably the normalized x position of the flip axis;
              # confirm against keypoint_flip_horizontal.
              lambda: keypoint_flip_horizontal(keypoints, 0.5,
                                               keypoint_flip_permutation),
              lambda: keypoints))

    return tuple(outputs)
Beispiel #10
0
def print_mlperf(key, value=None):
    """Forwards an MLPerf log entry, but only from the main process.

    Args:
        key: log key passed through to `maskrcnn_print`.
        value: optional log value passed through to `maskrcnn_print`.
    """
    # Processes for which is_main_process() is falsy stay silent.
    if not is_main_process():
        return
    maskrcnn_print(key=key, value=value)
def _total_train_steps():
    """Returns the total number of training steps implied by the epoch flags."""
    return int((FLAGS.num_epochs * float(FLAGS.num_examples_per_epoch)) /
               float(FLAGS.train_batch_size))


def _train_input_fn():
    """Builds a fresh training-mode input reader over the training files."""
    return dataloader.InputReader(FLAGS.training_file_pattern,
                                  mode=tf.estimator.ModeKeys.TRAIN)


def _make_train_estimator(run_config, params):
    """Builds the TPUEstimator used for training."""
    return tf.contrib.tpu.TPUEstimator(
        model_fn=mask_rcnn_model.mask_rcnn_model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        config=run_config,
        params=params)


def _make_eval_estimator(run_config, eval_params):
    """Builds the TPUEstimator used for evaluation (prediction)."""
    return tf.contrib.tpu.TPUEstimator(
        model_fn=mask_rcnn_model.mask_rcnn_model_fn,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size,
        config=run_config,
        params=eval_params)


def _open_eval_summary_writer():
    """Creates `<model_dir>/eval` and returns a summary writer for it."""
    output_dir = os.path.join(FLAGS.model_dir, 'eval')
    tf.gfile.MakeDirs(output_dir)
    # Summary writer writes out eval metrics.
    return tf.summary.FileWriter(output_dir)


def _eval_targets_met(eval_results):
    """Returns whether both the box and the mask AP reach their targets."""
    return (eval_results['AP'] >= BOX_EVAL_TARGET
            and eval_results['mask_AP'] >= MASK_EVAL_TARGET)


def _log_run_stop(success):
    """Emits the MLPerf RUN_STOP entry with the given success status."""
    mlperf_log.maskrcnn_print(
        key=mlperf_log.RUN_STOP,
        value={'success': 'true' if success else 'false'})


def _run_train(params, run_config):
    """Runs 'train' mode: train to the final step, then optionally evaluate."""
    max_steps = _total_train_steps()
    if params['dynamic_input_shapes']:
        train_with_dynamic_shapes(params, max_steps,
                                  FLAGS.iterations_per_loop)
    else:
        tf.logging.info(params)
        train_estimator = _make_train_estimator(run_config, params)
        train_estimator.train(input_fn=_train_input_fn(),
                              max_steps=max_steps)

    if not FLAGS.eval_after_training:
        return

    # Run evaluation after training finishes.
    eval_params = dict(
        params,
        use_tpu=FLAGS.use_tpu,
        input_rand_hflip=False,
        resnet_checkpoint=None,
        is_training_bn=False,
        dynamic_input_shapes=False,
        transpose_input=False,
    )
    eval_estimator = _make_eval_estimator(run_config, eval_params)

    summary_writer = _open_eval_summary_writer()
    try:
        eval_results = evaluation(eval_estimator, FLAGS.num_epochs,
                                  params['val_json_file'])
        write_summary(eval_results, summary_writer, max_steps)
        _log_run_stop(_eval_targets_met(eval_results))
    finally:
        # Close even if evaluation fails so the writer is not leaked.
        summary_writer.close()
    mlperf_log.maskrcnn_print(key=mlperf_log.RUN_FINAL)


def _run_eval(params, run_config):
    """Runs 'eval' mode: evaluate each new checkpoint until done or timeout."""
    summary_writer = _open_eval_summary_writer()
    try:
        eval_params = dict(
            params,
            use_tpu=FLAGS.use_tpu,
            input_rand_hflip=False,
            resnet_checkpoint=None,
            is_training_bn=False,
            transpose_input=False,
        )
        eval_estimator = _make_eval_estimator(run_config, eval_params)

        def terminate_eval():
            tf.logging.info(
                'Terminating eval after %d seconds of no checkpoints',
                FLAGS.eval_timeout)
            return True

        # Depends only on flags, so compute it once instead of per checkpoint.
        total_step = _total_train_steps()
        run_success = False
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.contrib.training.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.min_eval_interval,
                timeout=FLAGS.eval_timeout,
                timeout_fn=terminate_eval):
            # The step is parsed from the checkpoint file name; it decides
            # below whether the final checkpoint has been reached.
            current_step = int(os.path.basename(ckpt).split('-')[1])

            tf.logging.info('Starting to evaluate.')
            try:
                current_epoch = current_step / (float(
                    FLAGS.num_examples_per_epoch) / FLAGS.train_batch_size)
                eval_results = evaluation(eval_estimator, current_epoch,
                                          params['val_json_file'])
                write_summary(eval_results, summary_writer, current_step)
                if _eval_targets_met(eval_results):
                    _log_run_stop(True)
                    run_success = True
                    break

                if current_step >= total_step:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU
                # worker, sometimes the TPU worker does not finish
                # initializing until long after the CPU job tells it to start
                # evaluating. In this case, the checkpoint file could have
                # been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)
        if not run_success:
            _log_run_stop(False)
        mlperf_log.maskrcnn_print(key=mlperf_log.RUN_FINAL)
    finally:
        # Close even if evaluation fails so the writer is not leaked.
        summary_writer.close()


def _run_train_and_eval(params, run_config):
    """Runs 'train_and_eval' mode: evaluate after every training epoch."""
    summary_writer = _open_eval_summary_writer()
    try:
        train_estimator = _make_train_estimator(run_config, params)
        eval_params = dict(params,
                           use_tpu=FLAGS.use_tpu,
                           input_rand_hflip=False,
                           resnet_checkpoint=None,
                           is_training_bn=False,
                           dynamic_input_shapes=False)
        eval_estimator = _make_eval_estimator(run_config, eval_params)

        run_success = False
        steps_per_epoch = int(FLAGS.num_examples_per_epoch /
                              FLAGS.train_batch_size)
        whole_epochs = int(math.floor(FLAGS.num_epochs))
        for cycle in range(whole_epochs):
            # Global step expected once this epoch's training is done.
            steps_after_cycle = (cycle + 1) * steps_per_epoch
            tf.logging.info('Starting training cycle, epoch: %d.', cycle)
            mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_EPOCH, value=cycle)
            if params['dynamic_input_shapes']:
                tf.logging.info(
                    'Use dynamic input shapes training for %d steps. Train '
                    'to %d steps', steps_per_epoch, steps_after_cycle)
                train_with_dynamic_shapes(params, steps_after_cycle,
                                          FLAGS.iterations_per_loop)
            else:
                train_estimator.train(input_fn=_train_input_fn(),
                                      steps=steps_per_epoch)

            tf.logging.info('Starting evaluation cycle, epoch: %d.', cycle)
            # Run evaluation after every epoch.
            eval_results = evaluation(eval_estimator, cycle,
                                      params['val_json_file'])
            write_summary(eval_results, summary_writer, steps_after_cycle)
            if _eval_targets_met(eval_results):
                _log_run_stop(True)
                run_success = True
                break

        if not run_success:
            # Final epoch: covers the fractional remainder of num_epochs and
            # always reports the final RUN_STOP status.
            current_epoch = whole_epochs
            max_steps = _total_train_steps()
            tf.logging.info('Starting training cycle, epoch: %d.',
                            current_epoch)
            mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_EPOCH,
                                      value=current_epoch)
            if params['dynamic_input_shapes']:
                remaining_steps = max_steps - int(
                    current_epoch * steps_per_epoch)
                if remaining_steps > 0:
                    tf.logging.info(
                        'Use dynamic input shapes training for %d steps. '
                        'Train to %d steps', remaining_steps, max_steps)
                    train_with_dynamic_shapes(params, max_steps,
                                              remaining_steps)
            else:
                train_estimator.train(input_fn=_train_input_fn(),
                                      max_steps=max_steps)
            eval_results = evaluation(eval_estimator, current_epoch,
                                      params['val_json_file'])
            write_summary(eval_results, summary_writer, max_steps)
            _log_run_stop(_eval_targets_met(eval_results))
        mlperf_log.maskrcnn_print(key=mlperf_log.RUN_FINAL)
    finally:
        # Close even if training/evaluation fails so the writer is not leaked.
        summary_writer.close()


def main(argv):
    """Trains and/or evaluates Mask-RCNN according to `FLAGS.mode`.

    Validates the file flags required by the selected mode, assembles the
    model params and TPU run configuration, and then dispatches to the
    'train', 'eval' or 'train_and_eval' workflow. Any other mode only logs
    'Mode not found.'.

    Args:
      argv: command-line arguments; unused.

    Raises:
      RuntimeError: if --training_file_pattern is missing for a training
        mode, or --validation_file_pattern / --val_json_file is missing for
        an evaluation mode.
    """
    del argv  # Unused.
    tpu_cluster_resolver = create_tpu_cluster_resolver()
    if tpu_cluster_resolver:
        tpu_grpc_url = tpu_cluster_resolver.get_master()
        tf.Session.reset(tpu_grpc_url)

    # Check data path
    if FLAGS.mode in (
            'train', 'train_and_eval') and FLAGS.training_file_pattern is None:
        raise RuntimeError(
            'You must specify --training_file_pattern for training.')
    if FLAGS.mode in ('eval', 'train_and_eval'):
        if FLAGS.validation_file_pattern is None:
            raise RuntimeError('You must specify --validation_file_pattern '
                               'for evaluation.')
        if FLAGS.val_json_file is None:
            raise RuntimeError(
                'You must specify --val_json_file for evaluation.')

    # Parse hparams
    hparams = mask_rcnn_model.default_hparams()
    hparams.parse(FLAGS.hparams)

    params = dict(
        hparams.values(),
        num_shards=FLAGS.num_cores,
        num_examples_per_epoch=FLAGS.num_examples_per_epoch,
        use_tpu=FLAGS.use_tpu,
        resnet_checkpoint=FLAGS.resnet_checkpoint,
        val_json_file=FLAGS.val_json_file,
        mode=FLAGS.mode,
        # The following are used by the host_call function.
        model_dir=FLAGS.model_dir,
        iterations_per_loop=FLAGS.iterations_per_loop,
        dynamic_input_shapes=FLAGS.dynamic_input_shapes,
        transpose_input=FLAGS.transpose_input)

    tpu_config = tf.contrib.tpu.TPUConfig(
        FLAGS.iterations_per_loop,
        num_shards=FLAGS.num_cores,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.
        PER_HOST_V2)

    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        log_step_count_steps=FLAGS.iterations_per_loop,
        tpu_config=tpu_config,
    )

    # Every mode other than pure 'eval' emits the MLPerf run-start markers.
    if FLAGS.mode != 'eval':
        mlperf_log.maskrcnn_print(key=mlperf_log.RUN_START)
        mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_LOOP)
        mlperf_log.maskrcnn_print(key=mlperf_log.TRAIN_EPOCH, value=0)

    if FLAGS.mode == 'train':
        _run_train(params, run_config)
    elif FLAGS.mode == 'eval':
        _run_eval(params, run_config)
    elif FLAGS.mode == 'train_and_eval':
        _run_train_and_eval(params, run_config)
    else:
        tf.logging.info('Mode not found.')
def print_mlperf(key, value=None):
    """Forwards an MLPerf log entry unless this process has a positive rank.

    Args:
        key: log key passed through to `maskrcnn_print`.
        value: optional log value passed through to `maskrcnn_print`.
    """
    # Only a process whose rank is not greater than zero emits the entry.
    should_log = not get_rank() > 0
    if should_log:
        maskrcnn_print(key=key, value=value)