Ejemplo n.º 1
0
def efficientdet(features, model_name=None, config=None, **kwargs):
    """Build EfficientDet model."""
    if not config and not model_name:
        raise ValueError('please specify either model name or config')

    if not config:
        # Derive the default config from the model name.
        config = hparams_config.get_efficientdet_config(model_name)
    elif isinstance(config, dict):
        config = hparams_config.Config(config)  # wrap dict in Config object
    if kwargs:
        config.override(kwargs)
    logging.info(config)

    # Backbone features.
    backbone_feats = build_backbone(features, config)
    params, flops = utils.num_params_flops()
    logging.info('backbone params/flops = {:.6f}M, {:.9f}B'.format(
        params, flops))

    # Feature network on top of the backbone outputs.
    fpn_feats = build_feature_network(backbone_feats, config)
    params, flops = utils.num_params_flops()
    logging.info('backbone+fpn params/flops = {:.6f}M, {:.9f}B'.format(
        params, flops))

    # Class and box prediction heads.
    class_outputs, box_outputs = build_class_and_box_outputs(fpn_feats, config)
    params, flops = utils.num_params_flops()
    logging.info('backbone+fpn+box params/flops = {:.6f}M, {:.9f}B'.format(
        params, flops))

    return class_outputs, box_outputs
Ejemplo n.º 2
0
def main(_):
    """Train a d0 segmentation head and print a sample predicted mask."""
    batch_size = 8
    steps_per_epoch = info.splits['train'].num_examples // batch_size

    # Input pipelines: parallel decode for training, a plain map for test.
    train_pairs = dataset['train'].map(
        load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    test_pairs = dataset['test'].map(load_image_test)

    train_dataset = (train_pairs
                     .cache()
                     .shuffle(1000)
                     .batch(batch_size)
                     .repeat()
                     .prefetch(buffer_size=tf.data.experimental.AUTOTUNE))
    test_dataset = test_pairs.batch(batch_size)

    # Build a d0 network configured with only a segmentation head.
    config = hparams_config.get_efficientdet_config('efficientdet-d0')
    config.heads = ['segmentation']
    model = efficientdet_keras.EfficientDetNet(config=config)
    model.build((1, 512, 512, 3))
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])

    val_subsplits = 5
    val_steps = info.splits['test'].num_examples // batch_size // val_subsplits
    model.fit(train_dataset,
              epochs=20,
              steps_per_epoch=steps_per_epoch,
              validation_steps=val_steps,
              validation_data=test_dataset,
              callbacks=[])

    model.save_weights('./test/segmentation')

    print(create_mask(model(tf.ones((1, 512, 512, 3)), False)))
Ejemplo n.º 3
0
def main(_):
    """Evaluate a TFLite EfficientDet model with the COCO AP metric."""
    config = hparams_config.get_efficientdet_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    config.val_json_file = FLAGS.val_json_file
    config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS
    config.drop_remainder = False  # eval all examples w/o drop.
    config.image_size = utils.parse_image_size(config['image_size'])

    # Evaluator for AP calculation.
    label_map = label_util.get_label_map(config.label_map)
    evaluator = coco_metric.EvaluationMetric(filename=config.val_json_file,
                                             label_map=label_map)

    # dataset
    batch_size = 1
    ds = dataloader.InputReader(
        FLAGS.val_file_pattern,
        is_training=False,
        max_instances_per_image=config.max_instances_per_image)(
            config, batch_size=batch_size)
    eval_samples = FLAGS.eval_samples
    if eval_samples:
        # ceil(eval_samples / batch_size) batches.
        ds = ds.take((eval_samples + batch_size - 1) // batch_size)

    # Network
    lite_runner = LiteRunner(FLAGS.tflite_path)
    eval_samples = FLAGS.eval_samples or 5000
    pbar = tf.keras.utils.Progbar(
        (eval_samples + batch_size - 1) // batch_size)
    for i, (images, labels) in enumerate(ds):
        cls_outputs, box_outputs = lite_runner.run(images)
        detections = postprocess.generate_detections(config, cls_outputs,
                                                     box_outputs,
                                                     labels['image_scales'],
                                                     labels['source_ids'])
        detections = postprocess.transform_detections(detections)
        evaluator.update_state(labels['groundtruth_data'].numpy(),
                               detections.numpy())
        # `i` is 0-based: report the number of *finished* batches so the
        # progress bar can actually reach its target on the last step.
        pbar.update(i + 1)

    # compute the final eval results.
    metrics = evaluator.result()
    metric_dict = {}
    for i, name in enumerate(evaluator.metric_names):
        metric_dict[name] = metrics[i]

    if label_map:
        # Per-class AP values follow the aggregate metrics in `metrics`.
        for i, cid in enumerate(sorted(label_map.keys())):
            name = 'AP_/%s' % label_map[cid]
            metric_dict[name] = metrics[i + len(evaluator.metric_names)]
    print(FLAGS.model_name, metric_dict)
Ejemplo n.º 4
0
 def model_arch(feats, model_name=None, **kwargs):
     """Construct a model arch for keras models."""
     config = hparams_config.get_efficientdet_config(model_name)
     config.override(kwargs)
     model = efficientdet_keras.EfficientDetNet(config=config)
     cls_out_list, box_out_list = model(feats, training=False)
     # The model emits one output per feature level; re-key the flat lists
     # into dictionaries indexed by level number.
     levels = range(config.min_level, config.max_level + 1)
     assert len(cls_out_list) == len(levels)
     assert len(box_out_list) == len(levels)
     cls_outputs = dict(zip(levels, cls_out_list))
     box_outputs = dict(zip(levels, box_out_list))
     return cls_outputs, box_outputs
Ejemplo n.º 5
0
 def build(self, params_override=None):
     """Build model and restore checkpoints."""
     params = copy.deepcopy(self.params)
     if params_override:
         params.update(params_override)
     config = hparams_config.get_efficientdet_config(self.model_name)
     config.override(params)
     # Either the bare network or the full model with pre/post-processing.
     model_cls = (efficientdet_keras.EfficientDetNet if self.only_network
                  else efficientdet_keras.EfficientDetModel)
     self.model = model_cls(config=config)
     image_size = utils.parse_image_size(params['image_size'])
     self.model.build((self.batch_size, *image_size, 3))
     util_keras.restore_ckpt(self.model,
                             self.ckpt_path,
                             self.params['moving_average_decay'],
                             skip_mismatch=False)
Ejemplo n.º 6
0
def main(_):
  """Run single-image inference and write the annotated result to disk."""
  # pylint: disable=line-too-long
  # Prepare images and checkpoints: please run these commands in shell.
  # !mkdir tmp
  # !wget https://user-images.githubusercontent.com/11736571/77320690-099af300-6d37-11ea-9d86-24f14dc2d540.png -O tmp/img.png
  # !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-d0.tar.gz -O tmp/efficientdet-d0.tar.gz
  # !tar zxf tmp/efficientdet-d0.tar.gz -C tmp
  imgs = [np.array(Image.open(FLAGS.image_path))]
  # Create model config.
  config = hparams_config.get_efficientdet_config(FLAGS.model_name)
  config.is_training_bn = False
  config.image_size = '1920x1280'
  config.nms_configs.score_thresh = 0.4
  config.nms_configs.max_output_size = 100
  config.override(FLAGS.hparams)

  # Use 'mixed_float16' if running on GPUs.
  policy = tf.keras.mixed_precision.Policy('float32')
  tf.keras.mixed_precision.set_global_policy(policy)
  tf.config.run_functions_eagerly(FLAGS.debug)

  # Create and run the model.
  model = efficientdet_keras.EfficientDetModel(config=config)
  model.build((None, None, None, 3))
  model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir))
  model.summary()

  class ExportModel(tf.Module):
    """Wraps the keras model in a tf.Module exposing a serving function."""

    def __init__(self, model):
      super().__init__()
      self.model = model

    @tf.function
    def f(self, imgs):
      return self.model(imgs, training=False, post_mode='global')

  imgs = tf.convert_to_tensor(imgs, dtype=tf.uint8)
  export_model = ExportModel(model)
  if FLAGS.saved_model_dir:
    tf.saved_model.save(
        export_model,
        FLAGS.saved_model_dir,
        signatures=export_model.f.get_concrete_function(
            tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8)))
    # Reload so the exported artifact is what actually gets exercised below.
    export_model = tf.saved_model.load(FLAGS.saved_model_dir)

  boxes, scores, classes, valid_len = export_model.f(imgs)

  # Visualize results.
  for i, img in enumerate(imgs):
    length = valid_len[i]
    img = inference.visualize_image(
        img,
        boxes[i].numpy()[:length],
        # `np.int` was deprecated in NumPy 1.20 and removed in 1.24; the
        # builtin `int` is the documented, equivalent replacement.
        classes[i].numpy().astype(int)[:length],
        scores[i].numpy()[:length],
        label_map=config.label_map,
        min_score_thresh=config.nms_configs.score_thresh,
        max_boxes_to_draw=config.nms_configs.max_output_size)
    output_image_path = os.path.join(FLAGS.output_dir, str(i) + '.jpg')
    Image.fromarray(img).save(output_image_path)
    print('writing annotated image to %s' % output_image_path)
Ejemplo n.º 7
0
  def __init__(self,
               model_name: str,
               uri: str,
               hparams: str = '',
               model_dir: Optional[str] = None,
               epochs: int = 50,
               batch_size: int = 64,
               steps_per_execution: int = 1,
               moving_average_decay: float = 0,
               var_freeze_expr: str = '(efficientnet|fpn_cells|resample_p6)',
               tflite_max_detections: int = 25,
               strategy: Optional[str] = None,
               tpu: Optional[str] = None,
               gcp_project: Optional[str] = None,
               tpu_zone: Optional[str] = None,
               use_xla: bool = False,
               profile: bool = False,
               debug: bool = False,
               tf_random_seed: int = 111111,
               verbose: int = 0) -> None:
    """Initialize an instance with model parameters.

    Args:
      model_name: Model name.
      uri: TF-Hub path/url to EfficientDet module.
      hparams: Hyperparameters used to overwrite default configuration. Can be
        1) Dict, contains parameter names and values; 2) String, Comma separated
        k=v pairs of hyperparameters; 3) String, yaml filename containing
        attributes to use as hyperparameters.
      model_dir: The location to save the model checkpoint files.
      epochs: Default training epochs.
      batch_size: Training & Evaluation batch size.
      steps_per_execution: Number of steps per training execution.
      moving_average_decay: Float. The decay to use for maintaining moving
        averages of the trained parameters.
      var_freeze_expr: Expression to freeze variables.
      tflite_max_detections: The max number of output detections in the TFLite
        model.
      strategy:  A string specifying which distribution strategy to use.
        Accepted values are 'tpu', 'gpus', None. 'tpu' means to use
        TPUStrategy. 'gpus' mean to use MirroredStrategy for multi-gpus. If
        None, use TF default with OneDeviceStrategy.
      tpu: The Cloud TPU to use for training. This should be either the name
        used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470
          url.
      gcp_project: Project name for the Cloud TPU-enabled project. If not
        specified, we will attempt to automatically detect the GCE project from
        metadata.
      tpu_zone: GCE zone where the Cloud TPU is located in. If not specified, we
        will attempt to automatically detect the GCE project from metadata.
      use_xla: Use XLA even if strategy is not tpu. If strategy is tpu, always
        use XLA, and this flag has no effect.
      profile: Enable profile mode.
      debug: Enable debug mode.
      tf_random_seed: Fixed random seed for deterministic execution across runs
        for debugging.
      verbose: verbosity mode for `tf.keras.callbacks.ModelCheckpoint`, 0 or 1.
    """
    self.model_name = model_name
    self.uri = uri
    self.batch_size = batch_size
    config = hparams_config.get_efficientdet_config(model_name)
    config.override(hparams)
    config.image_size = utils.parse_image_size(config.image_size)
    config.var_freeze_expr = var_freeze_expr
    config.moving_average_decay = moving_average_decay
    config.tflite_max_detections = tflite_max_detections
    if epochs:
      config.num_epochs = epochs

    if use_xla and strategy != 'tpu':
      tf.config.optimizer.set_jit(True)
      for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    if debug:
      # Use the stable API (the `experimental_run_functions_eagerly` alias is
      # deprecated; `run_functions_eagerly` is used elsewhere in this code).
      tf.config.run_functions_eagerly(True)
      tf.debugging.set_log_device_placement(True)
      os.environ['TF_DETERMINISTIC_OPS'] = '1'
      tf.random.set_seed(tf_random_seed)
      logging.set_verbosity(logging.DEBUG)

    # Pick a distribution strategy: TPU > multi-GPU > single device.
    if strategy == 'tpu':
      tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          tpu, zone=tpu_zone, project=gcp_project)
      tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
      tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
      ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
      logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
      tf.config.set_soft_device_placement(True)
    elif strategy == 'gpus':
      ds_strategy = tf.distribute.MirroredStrategy()
      logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
    else:
      if tf.config.list_physical_devices('GPU'):
        ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
      else:
        ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')

    self.ds_strategy = ds_strategy

    if model_dir is None:
      model_dir = tempfile.mkdtemp()
    params = dict(
        profile=profile,
        model_name=model_name,
        steps_per_execution=steps_per_execution,
        model_dir=model_dir,
        strategy=strategy,
        batch_size=batch_size,
        tf_random_seed=tf_random_seed,
        debug=debug,
        verbose=verbose)
    config.override(params, True)
    self.config = config

    # Set mixed precision policy by keras api. The `experimental` mixed
    # precision namespace is deprecated; use the stable API as the rest of
    # the codebase does.
    precision = utils.get_precision(config.strategy, config.mixed_precision)
    policy = tf.keras.mixed_precision.Policy(precision)
    tf.keras.mixed_precision.set_global_policy(policy)
Ejemplo n.º 8
0
    def export(self,
               output_dir: Optional[Text] = None,
               tensorrt: Optional[Text] = None,
               tflite: Optional[Text] = None,
               file_pattern: Optional[Text] = None,
               num_calibration_steps: int = 2000):
        """Export a saved model, frozen graph, and potential tflite/tensorrt model.

        Args:
          output_dir: the output folder for saved model.
          tensorrt: If not None, must be {'FP32', 'FP16', 'INT8'}.
          tflite: Type for post-training quantization.
          file_pattern: Glob for tfrecords, e.g. coco/val-*.tfrecord.
          num_calibration_steps: Number of post-training quantization calibration
            steps to run.
        """
        export_model, input_spec = self._get_model_and_spec(tflite)
        image_size = utils.parse_image_size(self.params['image_size'])
        if output_dir:
            tf.saved_model.save(
                export_model,
                output_dir,
                signatures=export_model.__call__.get_concrete_function(
                    input_spec))
            logging.info('Model saved at %s', output_dir)

            # also save freeze pb file.
            graphdef = self.freeze(
                export_model.__call__.get_concrete_function(input_spec))
            proto_path = tf.io.write_graph(graphdef,
                                           output_dir,
                                           self.model_name + '_frozen.pb',
                                           as_text=False)
            logging.info('Frozen graph saved at %s', proto_path)

        if tflite:
            shape = (self.batch_size, *image_size, 3)
            input_spec = tf.TensorSpec(shape=shape,
                                       dtype=input_spec.dtype,
                                       name=input_spec.name)
            # from_saved_model supports advanced converter features like op
            # fusing. NOTE(review): this reads the saved model written above,
            # so calling export(tflite=...) without output_dir will fail here
            # — confirm callers always pass both.
            converter = tf.lite.TFLiteConverter.from_saved_model(output_dir)
            if tflite == 'FP32':
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.target_spec.supported_types = [tf.float32]
            elif tflite == 'FP16':
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.target_spec.supported_types = [tf.float16]
            elif tflite == 'INT8':
                # Enables MLIR-based post-training quantization.
                converter.experimental_new_quantizer = True
                if file_pattern:
                    config = hparams_config.get_efficientdet_config(
                        self.model_name)
                    config.override(self.params)
                    ds = dataloader.InputReader(file_pattern,
                                                is_training=False,
                                                max_instances_per_image=config.
                                                max_instances_per_image)(
                                                    config,
                                                    batch_size=self.batch_size)

                    def representative_dataset_gen():
                        # Feed real images for calibration.
                        for image, _ in ds.take(num_calibration_steps):
                            yield [image]
                else:  # Used for debugging, can remove later.
                    # `logging.warn` is a deprecated alias of `warning`.
                    logging.warning(
                        'Use real representative dataset instead of fake ones.'
                    )
                    num_calibration_steps = 10

                    def representative_dataset_gen(
                    ):  # rewrite this for real data.
                        for _ in range(num_calibration_steps):
                            yield [tf.ones(shape, dtype=input_spec.dtype)]

                converter.representative_dataset = representative_dataset_gen
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.inference_input_type = tf.uint8
                # TFLite's custom NMS op isn't supported by post-training quant,
                # so we add TFLITE_BUILTINS as well.
                supported_ops = [
                    tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
                    tf.lite.OpsSet.TFLITE_BUILTINS
                ]
                converter.target_spec.supported_ops = supported_ops

            else:
                raise ValueError(
                    f'Invalid tflite {tflite}: must be FP32, FP16, INT8.')

            tflite_path = os.path.join(output_dir, tflite.lower() + '.tflite')
            tflite_model = converter.convert()
            # Use a context manager so the file handle is flushed and closed
            # even if the write raises.
            with tf.io.gfile.GFile(tflite_path, 'wb') as f:
                f.write(tflite_model)
            logging.info('TFLite is saved at %s', tflite_path)

        if tensorrt:
            trt_path = os.path.join(output_dir, 'tensorrt_' + tensorrt.lower())
            conversion_params = tf.experimental.tensorrt.ConversionParams(
                max_workspace_size_bytes=(2 << 20),
                maximum_cached_engines=1,
                precision_mode=tensorrt.upper())
            converter = tf.experimental.tensorrt.Converter(
                output_dir, conversion_params=conversion_params)
            converter.convert()
            converter.save(trt_path)
            logging.info('TensorRT model is saved at %s', trt_path)
Ejemplo n.º 9
0
    def __init__(self,
                 model_name=None,
                 config=None,
                 name='',
                 feature_only=False):
        """Initialize model.

        Args:
          model_name: used to look up a default config when `config` is falsy.
          config: optional config object; takes precedence over `model_name`.
          name: keras name for this model.
          feature_only: forwarded to ClassNet/BoxNet; presumably limits them
            to feature extraction without final predictions — confirm in
            those classes.
        """
        super().__init__(name=name)

        # `config` wins over `model_name` when both are given.
        config = config or hparams_config.get_efficientdet_config(model_name)
        self.config = config

        # Backbone.
        backbone_name = config.backbone_name
        is_training_bn = config.is_training_bn
        if 'efficientnet' in backbone_name:
            # Overrides applied on top of the stock EfficientNet backbone.
            override_params = {
                'batch_norm':
                utils.batch_norm_class(is_training_bn, config.strategy),
                'relu_fn':
                functools.partial(utils.activation_fn,
                                  act_type=config.act_type),
                'grad_checkpoint':
                self.config.grad_checkpoint
            }
            if 'b0' in backbone_name:
                # NOTE(review): survival_prob=0.0 for b0 backbones — verify
                # this is the intended stochastic-depth setting.
                override_params['survival_prob'] = 0.0
            if config.backbone_config is not None:
                # Custom block layout encoded into the backbone's string format.
                override_params['blocks_args'] = (
                    efficientnet_builder.BlockDecoder().encode(
                        config.backbone_config.blocks))
            override_params['data_format'] = config.data_format
            self.backbone = backbone_factory.get_model(
                backbone_name, override_params=override_params)

        # Feature network.
        self.resample_layers = []  # additional resampling layers.
        for level in range(6, config.max_level + 1):
            # Adds a coarser level by downsampling the last feature map.
            self.resample_layers.append(
                ResampleFeatureMap(
                    feat_level=(level - config.min_level),
                    target_num_channels=config.fpn_num_filters,
                    apply_bn=config.apply_bn_for_resampling,
                    is_training_bn=config.is_training_bn,
                    conv_after_downsample=config.conv_after_downsample,
                    strategy=config.strategy,
                    data_format=config.data_format,
                    model_optimizations=config.model_optimizations,
                    name='resample_p%d' % level,
                ))
        self.fpn_cells = FPNCells(config)

        # class/box output prediction network.
        num_anchors = len(config.aspect_ratios) * config.num_scales
        num_filters = config.fpn_num_filters
        # Build one head per entry in config.heads; detection creates both a
        # class net and a box net, segmentation creates its own head.
        for head in config.heads:
            if head == 'object_detection':
                self.class_net = ClassNet(
                    num_classes=config.num_classes,
                    num_anchors=num_anchors,
                    num_filters=num_filters,
                    min_level=config.min_level,
                    max_level=config.max_level,
                    is_training_bn=config.is_training_bn,
                    act_type=config.act_type,
                    repeats=config.box_class_repeats,
                    separable_conv=config.separable_conv,
                    survival_prob=config.survival_prob,
                    strategy=config.strategy,
                    grad_checkpoint=config.grad_checkpoint,
                    data_format=config.data_format,
                    feature_only=feature_only)

                self.box_net = BoxNet(num_anchors=num_anchors,
                                      num_filters=num_filters,
                                      min_level=config.min_level,
                                      max_level=config.max_level,
                                      is_training_bn=config.is_training_bn,
                                      act_type=config.act_type,
                                      repeats=config.box_class_repeats,
                                      separable_conv=config.separable_conv,
                                      survival_prob=config.survival_prob,
                                      strategy=config.strategy,
                                      grad_checkpoint=config.grad_checkpoint,
                                      data_format=config.data_format,
                                      feature_only=feature_only)

            if head == 'segmentation':
                self.seg_head = SegmentationHead(
                    num_classes=config.seg_num_classes,
                    num_filters=num_filters,
                    min_level=config.min_level,
                    max_level=config.max_level,
                    is_training_bn=config.is_training_bn,
                    act_type=config.act_type,
                    strategy=config.strategy,
                    data_format=config.data_format)
Ejemplo n.º 10
0
def main(_):
    """Evaluate a TFLite model (full pipeline or bare network) on COCO."""
    config = hparams_config.get_efficientdet_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    config.val_json_file = FLAGS.val_json_file
    config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS
    config.drop_remainder = False  # eval all examples w/o drop.
    config.image_size = utils.parse_image_size(config['image_size'])

    # Evaluator for AP calculation.
    label_map = label_util.get_label_map(config.label_map)
    evaluator = coco_metric.EvaluationMetric(filename=config.val_json_file,
                                             label_map=label_map)

    # dataset
    batch_size = 1
    ds = dataloader.InputReader(
        FLAGS.val_file_pattern,
        is_training=False,
        max_instances_per_image=config.max_instances_per_image)(
            config, batch_size=batch_size)
    eval_samples = FLAGS.eval_samples
    if eval_samples:
        # ceil(eval_samples / batch_size) batches.
        ds = ds.take((eval_samples + batch_size - 1) // batch_size)

    # Network
    lite_runner = LiteRunner(FLAGS.tflite_path, FLAGS.only_network)
    eval_samples = FLAGS.eval_samples or 5000
    pbar = tf.keras.utils.Progbar(
        (eval_samples + batch_size - 1) // batch_size)
    for i, (images, labels) in enumerate(ds):
        if not FLAGS.only_network:
            # Full model: TFLite already ran NMS; undo normalization/scaling.
            nms_boxes_bs, nms_classes_bs, nms_scores_bs, _ = lite_runner.run(
                images)
            nms_classes_bs += postprocess.CLASS_OFFSET

            height, width = utils.parse_image_size(config.image_size)
            normalize_factor = tf.constant([height, width, height, width],
                                           dtype=tf.float32)
            nms_boxes_bs *= normalize_factor
            if labels['image_scales'] is not None:
                scales = tf.expand_dims(
                    tf.expand_dims(labels['image_scales'], -1), -1)
                nms_boxes_bs = nms_boxes_bs * tf.cast(scales,
                                                      nms_boxes_bs.dtype)
            detections = postprocess.generate_detections_from_nms_output(
                nms_boxes_bs, nms_classes_bs, nms_scores_bs,
                labels['source_ids'])
        else:
            # Bare network: run post-processing (incl. NMS) on the host.
            cls_outputs, box_outputs = lite_runner.run(images)
            detections = postprocess.generate_detections(
                config,
                cls_outputs,
                box_outputs,
                labels['image_scales'],
                labels['source_ids'],
                pre_class_nms=FLAGS.pre_class_nms)

        detections = postprocess.transform_detections(detections)
        evaluator.update_state(labels['groundtruth_data'].numpy(),
                               detections.numpy())
        # `i` is 0-based: report the number of *finished* batches so the
        # progress bar can actually reach its target on the last step.
        pbar.update(i + 1)

    # compute the final eval results.
    metrics = evaluator.result()
    metric_dict = {}
    for i, name in enumerate(evaluator.metric_names):
        metric_dict[name] = metrics[i]

    if label_map:
        # Per-class AP values follow the aggregate metrics in `metrics`.
        for i, cid in enumerate(sorted(label_map.keys())):
            name = 'AP_/%s' % label_map[cid]
            metric_dict[name] = metrics[i + len(evaluator.metric_names)]
    print(FLAGS.model_name, metric_dict)
Ejemplo n.º 11
0
def main(_):
    """Evaluate a checkpointed EfficientDet on COCO under a tf.distribute strategy."""
    config = hparams_config.get_efficientdet_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    config.val_json_file = FLAGS.val_json_file
    config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS
    config.drop_remainder = False  # eval all examples w/o drop.
    config.image_size = utils.parse_image_size(config['image_size'])

    # Pick a distribution strategy: TPU > multi-GPU > single device.
    if config.strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
        logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
    elif config.strategy == 'gpus':
        ds_strategy = tf.distribute.MirroredStrategy()
        logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
    else:
        if tf.config.list_physical_devices('GPU'):
            ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
        else:
            ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')

    with ds_strategy.scope():
        # Network
        model = efficientdet_keras.EfficientDetNet(config=config)
        model.build((None, *config.image_size, 3))
        util_keras.restore_ckpt(model,
                                tf.train.latest_checkpoint(FLAGS.model_dir),
                                config.moving_average_decay,
                                skip_mismatch=False)

        @tf.function
        def model_fn(images, labels):
            cls_outputs, box_outputs = model(images, training=False)
            detections = postprocess.generate_detections(
                config, cls_outputs, box_outputs, labels['image_scales'],
                labels['source_ids'])
            # Update the (python-side) evaluator from inside the tf.function.
            tf.numpy_function(evaluator.update_state, [
                labels['groundtruth_data'],
                postprocess.transform_detections(detections)
            ], [])

        # Evaluator for AP calculation.
        label_map = label_util.get_label_map(config.label_map)
        evaluator = coco_metric.EvaluationMetric(filename=config.val_json_file,
                                                 label_map=label_map)

        # dataset
        batch_size = FLAGS.batch_size  # global batch size.
        ds = dataloader.InputReader(
            FLAGS.val_file_pattern,
            is_training=False,
            max_instances_per_image=config.max_instances_per_image)(
                config, batch_size=batch_size)
        if FLAGS.eval_samples:
            # ceil(eval_samples / batch_size) batches.
            ds = ds.take((FLAGS.eval_samples + batch_size - 1) // batch_size)
        ds = ds_strategy.experimental_distribute_dataset(ds)

        # evaluate all images.
        eval_samples = FLAGS.eval_samples or 5000
        pbar = tf.keras.utils.Progbar(
            (eval_samples + batch_size - 1) // batch_size)
        for i, (images, labels) in enumerate(ds):
            ds_strategy.run(model_fn, (images, labels))
            # `i` is 0-based: report the number of *finished* batches so the
            # progress bar can actually reach its target on the last step.
            pbar.update(i + 1)

    # compute the final eval results.
    metrics = evaluator.result()
    metric_dict = {}
    for i, name in enumerate(evaluator.metric_names):
        metric_dict[name] = metrics[i]

    if label_map:
        # Per-class AP values follow the aggregate metrics in `metrics`.
        for i, cid in enumerate(sorted(label_map.keys())):
            name = 'AP_/%s' % label_map[cid]
            metric_dict[name] = metrics[i + len(evaluator.metric_names)]
    print(FLAGS.model_name, metric_dict)