示例#1
0
    def __init__(self,
                 model_name: Text,
                 ckpt_path: Text = None,
                 batch_size: int = 1,
                 only_network: bool = False,
                 model_params: Dict[Text, Any] = None):
        """Set up a serving driver for an EfficientDet model.

        Args:
          model_name: detection config to load, e.g. efficientdet-d0.
          ckpt_path: checkpoint location, e.g. /tmp/efficientdet-d0/.
          batch_size: number of images per inference call.
          only_network: if True, run only the network, without pre/post
            processing.
          model_params: optional overrides merged into the config dict.
        """
        super().__init__()
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.batch_size = batch_size
        self.only_network = only_network

        params = self.params = hparams_config.get_detection_config(
            model_name).as_dict()
        # Caller overrides are applied first; is_training_bn is then pinned to
        # False so an override cannot turn it back on.
        if model_params:
            params.update(model_params)
        params.update(is_training_bn=False)
        self.label_map = params.get('label_map', None)

        self._model = None

        # Derive the precision from the resolved strategy / mixed_precision
        # settings and install it as the global Keras mixed-precision policy.
        precision = utils.get_precision(params.get('strategy', None),
                                        params.get('mixed_precision', None))
        tf.keras.mixed_precision.set_global_policy(
            tf.keras.mixed_precision.Policy(precision))
示例#2
0
    def _build_model(self, grad_checkpoint=False):
        """Build and compile a small EfficientDet-d0 training model for tests.

        Args:
          grad_checkpoint: value stored in config.grad_checkpoint.

        Returns:
          A tuple (params, x, labels, model):
            params: the dict used to configure the optimizer and losses.
            x: an all-ones input tensor of shape (1, 512, 512, 3).
            labels: all-ones box/class targets for levels 3-7 plus
              'image_masks' and 'mean_num_positives'.
            model: the compiled train_lib.EfficientDetNetTrain instance.
        """
        tf.random.set_seed(1111)
        config = hparams_config.get_detection_config('efficientdet-d0')
        config.heads = ['object_detection', 'segmentation']
        config.batch_size = 1
        config.num_examples_per_epoch = 1
        # A single temporary directory suffices: params inherits it through
        # config.as_dict(), so no second (orphaned) mkdtemp() is needed.
        config.model_dir = tempfile.mkdtemp()
        config.steps_per_epoch = 1
        config.mixed_precision = True
        config.grad_checkpoint = grad_checkpoint
        x = tf.ones((1, 512, 512, 3))
        labels = {
            'box_targets_%d' % i: tf.ones((1, 512 // 2**i, 512 // 2**i, 36))
            for i in range(3, 8)
        }
        labels.update({
            'cls_targets_%d' % i: tf.ones((1, 512 // 2**i, 512 // 2**i, 9),
                                          dtype=tf.int32)
            for i in range(3, 8)
        })
        labels.update({'image_masks': tf.ones((1, 128, 128, 1))})
        labels.update({'mean_num_positives': tf.constant([10.0])})

        params = config.as_dict()
        params['num_shards'] = 1
        params['steps_per_execution'] = 100
        params['profile'] = False
        config.override(params, allow_new_keys=True)
        model = train_lib.EfficientDetNetTrain(config=config)
        model.build((1, 512, 512, 3))
        model.compile(
            optimizer=train_lib.get_optimizer(params),
            loss={
                train_lib.BoxLoss.__name__:
                train_lib.BoxLoss(params['delta'],
                                  reduction=tf.keras.losses.Reduction.NONE),
                train_lib.BoxIouLoss.__name__:
                train_lib.BoxIouLoss(params['iou_loss_type'],
                                     params['min_level'],
                                     params['max_level'],
                                     params['num_scales'],
                                     params['aspect_ratios'],
                                     params['anchor_scale'],
                                     params['image_size'],
                                     reduction=tf.keras.losses.Reduction.NONE),
                train_lib.FocalLoss.__name__:
                train_lib.FocalLoss(params['alpha'],
                                    params['gamma'],
                                    label_smoothing=params['label_smoothing'],
                                    reduction=tf.keras.losses.Reduction.NONE),
                tf.keras.losses.SparseCategoricalCrossentropy.__name__:
                tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            })
        return params, x, labels, model
示例#3
0
def main(_):
  """Visualize parsed input records for FLAGS.model_name.

  Loads the detection config, applies FLAGS.hparams overrides, normalizes
  image_size and runs RecordInspect(config).visualize(). A failure is logged
  together with its traceback rather than propagated.
  """
  # Parse and override hparams
  config = hparams_config.get_detection_config(FLAGS.model_name)
  config.override(FLAGS.hparams)

  # Parse image size in case it is in string format.
  config.image_size = utils.parse_image_size(config.image_size)
  try:
    recordinspect = RecordInspect(config)
    recordinspect.visualize()
  except Exception:  # pylint: disable=broad-except
    # logging.exception records the traceback; logging.error(e) only kept the
    # message, which made failures hard to diagnose.
    logging.exception('Record inspection failed.')
  else:
    logging.info('Done, please find samples at %s', FLAGS.save_samples_dir)
示例#4
0
 def test_display_callback(self):
     """DisplayCallback runs a batch-end update on a fake JPEG sample image."""
     config = hparams_config.get_detection_config('efficientdet-d0')
     config.batch_size = 1
     config.num_examples_per_epoch = 1
     config.model_dir = tempfile.mkdtemp()
     fake_image = tf.ones([512, 512, 3], dtype=tf.uint8)
     fake_jpeg = tf.image.encode_jpeg(fake_image)
     # Join as separate path components so the sample lands inside model_dir;
     # string concatenation would produce '<model_dir>fake_image.jpg' instead.
     sample_image = os.path.join(config.model_dir, 'fake_image.jpg')
     tf.io.write_file(sample_image, fake_jpeg)
     display_callback = train_lib.DisplayCallback(sample_image,
                                                  config.model_dir,
                                                  update_freq=1)
     model = train_lib.EfficientDetNetTrain(config=config)
     model.build((1, 512, 512, 3))
     display_callback.set_model(model)
     display_callback.on_train_batch_end(0)
示例#5
0
    def __init__(self,
                 model_name: Text,
                 logdir: Text,
                 tensorrt: Text = False,
                 use_xla: bool = False,
                 ckpt_path: Text = None,
                 export_ckpt: Text = None,
                 saved_model_dir: Text = None,
                 tflite_path: Text = None,
                 batch_size: int = 1,
                 hparams: Text = '',
                 **kwargs):
        """Resolve the model config and the input/label shapes for inspection.

        Args:
          model_name: detection config name, e.g. efficientdet-d0.
          logdir: log directory (stored as-is).
          tensorrt: TensorRT setting (stored as-is).
          use_xla: XLA setting (stored as-is).
          ckpt_path: checkpoint path.
          export_ckpt: checkpoint export path.
          saved_model_dir: SavedModel directory.
          tflite_path: TFLite output path.
          batch_size: batch size; a falsy value is stored as None in
            self.batch_size.
          hparams: override string passed to model_config.override.
          **kwargs: truthy 'score_thresh', 'nms_method' and 'max_output_size'
            values are copied into model_config.nms_configs as score_thresh,
            method and max_output_size respectively.
        """
        self.model_name = model_name
        self.logdir = logdir
        self.tensorrt = tensorrt
        self.use_xla = use_xla
        self.ckpt_path = ckpt_path
        self.export_ckpt = export_ckpt
        self.saved_model_dir = saved_model_dir
        self.tflite_path = tflite_path

        cfg = hparams_config.get_detection_config(model_name)
        cfg.override(hparams)  # user-supplied overrides
        cfg.is_training_bn = False
        cfg.image_size = utils.parse_image_size(cfg.image_size)

        # A batch size of 0 requests a graph with a dynamic batch dimension.
        self.batch_size = batch_size or None
        # NOTE(review): labels_shape / inputs_shape use the raw batch_size
        # argument, so batch_size=0 gives a leading dimension of 0, not None.
        self.labels_shape = [batch_size, cfg.num_classes]

        # Map flag-style kwargs onto NMS config attributes; only truthy values
        # override the config.
        nms_overrides = (
            ('score_thresh', 'score_thresh'),
            ('nms_method', 'method'),
            ('max_output_size', 'max_output_size'),
        )
        for kwarg_key, nms_attr in nms_overrides:
            if kwargs.get(kwarg_key, None):
                setattr(cfg.nms_configs, nms_attr, kwargs[kwarg_key])

        height, width = cfg.image_size
        channels_first = cfg.data_format == 'channels_first'
        self.inputs_shape = ([batch_size, 3, height, width] if channels_first
                             else [batch_size, height, width, 3])

        self.model_config = cfg
示例#6
0
    def __init__(self,
                 model_name: Text,
                 ckpt_path: Text,
                 model_params: Dict[Text, Any] = None):
        """Record the model identity and resolve its inference parameters.

        Args:
          model_name: detection config name, e.g. efficientdet-d0.
          ckpt_path: checkpoint path, e.g. /tmp/efficientdet-d0/.
          model_params: optional overrides merged into the config dict.
        """
        self.model_name, self.ckpt_path = model_name, ckpt_path
        params = self.params = hparams_config.get_detection_config(
            model_name).as_dict()
        # Caller overrides first, then is_training_bn is always forced False.
        if model_params:
            params.update(model_params)
        params.update(is_training_bn=False)
        self.label_map = params.get('label_map', None)
示例#7
0
 def test_parser(self):
     """Parsing one fake TFRecord example produces an 11-element result."""
     tf.random.set_seed(111111)
     cfg = hparams_config.get_detection_config('efficientdet-d0').as_dict()
     anchor_keys = ('min_level', 'max_level', 'num_scales', 'aspect_ratios',
                    'anchor_scale', 'image_size')
     labeler = anchors.AnchorLabeler(
         anchors.Anchors(*[cfg[key] for key in anchor_keys]),
         cfg['num_classes'])
     decoder = tf_example_decoder.TfExampleDecoder(
         regenerate_source_id=cfg['regenerate_source_id'])
     record_path = test_util.make_fake_tfrecord(self.get_temp_dir())
     first_record = next(iter(tf.data.TFRecordDataset([record_path])))
     reader = dataloader.InputReader(record_path, True)
     parsed = reader.dataset_parser(first_record, decoder, labeler, cfg)
     self.assertEqual(len(parsed), 11)
示例#8
0
    def __init__(self,
                 model_name: Text,
                 ckpt_path: Text,
                 batch_size: int = 1,
                 use_xla: bool = False,
                 min_score_thresh: float = None,
                 max_boxes_to_draw: float = None,
                 line_thickness: int = None,
                 model_params: Dict[Text, Any] = None):
        """Record model settings and resolve inference parameters.

        Args:
          model_name: detection config name, e.g. efficientdet-d0.
          ckpt_path: checkpoint path, e.g. /tmp/efficientdet-d0/.
          batch_size: batch size used for inference.
          use_xla: whether to run with XLA optimization.
          min_score_thresh: minimum score for a prediction to be kept.
          max_boxes_to_draw: maximum number of boxes drawn per image.
          line_thickness: line thickness used when drawing boxes.
          model_params: optional overrides merged into the config dict.
        """
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.batch_size = batch_size

        resolved = self.params = hparams_config.get_detection_config(
            model_name).as_dict()
        # Caller overrides first, then is_training_bn is always forced False.
        if model_params:
            resolved.update(model_params)
        resolved.update(is_training_bn=False)
        self.label_map = resolved.get('label_map', None)

        # Both start unset here; presumably filled in by a later build/load
        # step — TODO confirm. (Attribute name spelling is kept as-is.)
        self.signitures, self.sess = None, None
        self.use_xla = use_xla

        # Drawing options, stored unchanged.
        self.min_score_thresh = min_score_thresh
        self.max_boxes_to_draw = max_boxes_to_draw
        self.line_thickness = line_thickness
示例#9
0
def main(_):
    """Run the EfficientDet inference tool in the mode given by FLAGS.mode.

    Modes:
      export:    export to FLAGS.saved_model_dir, removing any existing
                 directory there first.
      infer:     detect on FLAGS.input_image and write the visualization to
                 FLAGS.output_image_dir/0.jpg.
      benchmark: benchmark on FLAGS.input_image (tiled to the batch size), or
                 on synthetic all-ones images when no image is given.
      dry:       build the model, optionally saving weights to
                 FLAGS.export_ckpt.
      video:     detect frame by frame on FLAGS.input_video, writing to
                 FLAGS.output_video or showing frames in a window.

    Raises:
      ValueError: in export mode when --saved_model_dir is not set.
    """
    tf.config.run_functions_eagerly(FLAGS.debug)
    # Enable memory growth on every visible GPU.
    for device in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(device, True)

    model_config = hparams_config.get_detection_config(FLAGS.model_name)
    model_config.override(FLAGS.hparams)  # Add custom overrides
    model_config.is_training_bn = False
    if FLAGS.image_size != -1:
        model_config.image_size = FLAGS.image_size
    model_config.image_size = utils.parse_image_size(model_config.image_size)

    model_params = model_config.as_dict()
    ckpt_path_or_file = FLAGS.model_dir
    if tf.io.gfile.isdir(ckpt_path_or_file):
        # A directory was given: resolve it to its latest checkpoint.
        ckpt_path_or_file = tf.train.latest_checkpoint(ckpt_path_or_file)
    # A batch size of 0 is passed to the driver as None.
    driver = inference.ServingDriver(FLAGS.model_name, ckpt_path_or_file,
                                     FLAGS.batch_size or None,
                                     FLAGS.only_network, model_params)
    if FLAGS.mode == 'export':
        if not FLAGS.saved_model_dir:
            raise ValueError('Please specify --saved_model_dir=')
        model_dir = FLAGS.saved_model_dir
        if tf.io.gfile.exists(model_dir):
            tf.io.gfile.rmtree(model_dir)
        driver.export(model_dir, FLAGS.tensorrt, FLAGS.tflite,
                      FLAGS.file_pattern, FLAGS.num_calibration_steps)
        print('Model are exported to %s' % model_dir)
    elif FLAGS.mode == 'infer':
        image_file = tf.io.read_file(FLAGS.input_image)
        image_arrays = tf.io.decode_image(image_file)
        image_arrays.set_shape((None, None, 3))
        image_arrays = tf.expand_dims(image_arrays, axis=0)
        if FLAGS.saved_model_dir:
            driver.load(FLAGS.saved_model_dir)
            if FLAGS.saved_model_dir.endswith('.tflite'):
                # TFLite models get a padded, resized uint8 input.
                image_arrays = tf.image.resize_with_pad(
                    image_arrays, *model_config.image_size)
                image_arrays = tf.cast(image_arrays, tf.uint8)
        detections_bs = driver.serve(image_arrays)
        boxes, scores, classes, _ = tf.nest.map_structure(
            np.array, detections_bs)
        raw_image = Image.fromarray(np.array(image_arrays)[0])
        img = driver.visualize(
            raw_image,
            boxes[0],
            classes[0],
            scores[0],
            min_score_thresh=model_config.nms_configs.score_thresh or 0.4,
            max_boxes_to_draw=model_config.nms_configs.max_output_size)
        output_image_path = os.path.join(FLAGS.output_image_dir, '0.jpg')
        Image.fromarray(img).save(output_image_path)
        print('writing file to %s' % output_image_path)
    elif FLAGS.mode == 'benchmark':
        if FLAGS.saved_model_dir:
            driver.load(FLAGS.saved_model_dir)

        batch_size = FLAGS.batch_size or 1
        if FLAGS.input_image:
            image_file = tf.io.read_file(FLAGS.input_image)
            image_arrays = tf.image.decode_image(image_file)
            image_arrays.set_shape((None, None, 3))
            image_arrays = tf.expand_dims(image_arrays, 0)
            if batch_size > 1:
                image_arrays = tf.tile(image_arrays, [batch_size, 1, 1, 1])
        else:
            # use synthetic data if no image is provided.
            image_arrays = tf.ones((batch_size, *model_config.image_size, 3),
                                   dtype=tf.uint8)
        if FLAGS.only_network:
            image_arrays = tf.image.convert_image_dtype(
                image_arrays, tf.float32)
            image_arrays = tf.image.resize(image_arrays,
                                           model_config.image_size)
        driver.benchmark(image_arrays, FLAGS.bm_runs, FLAGS.trace_filename)
    elif FLAGS.mode == 'dry':
        # transfer to tf2 format ckpt
        driver.build()
        if FLAGS.export_ckpt:
            driver.model.save_weights(FLAGS.export_ckpt)
    elif FLAGS.mode == 'video':
        import cv2  # pylint: disable=g-import-not-at-top
        if FLAGS.saved_model_dir:
            driver.load(FLAGS.saved_model_dir)
        cap = cv2.VideoCapture(FLAGS.input_video)
        if not cap.isOpened():
            print('Error opening input video: {}'.format(FLAGS.input_video))
            # Nothing to process; avoid creating an empty output video.
            cap.release()
            return

        # Visualization thresholds do not change per frame; read them once.
        min_score_thresh = model_config.nms_configs.score_thresh or 0.4
        max_boxes_to_draw = model_config.nms_configs.max_output_size

        out_ptr = None
        try:
            if FLAGS.output_video:
                frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                out_ptr = cv2.VideoWriter(
                    FLAGS.output_video,
                    cv2.VideoWriter_fourcc('m', 'p', '4', 'v'),
                    cap.get(cv2.CAP_PROP_FPS), (frame_width, frame_height))

            while cap.isOpened():
                # Capture frame-by-frame
                ret, frame = cap.read()
                if not ret:
                    break

                raw_frames = np.array([frame])
                detections_bs = driver.serve(raw_frames)
                boxes, scores, classes, _ = tf.nest.map_structure(
                    np.array, detections_bs)
                new_frame = driver.visualize(
                    raw_frames[0],
                    boxes[0],
                    classes[0],
                    scores[0],
                    min_score_thresh=min_score_thresh,
                    max_boxes_to_draw=max_boxes_to_draw)

                if out_ptr is not None:
                    # write frame into output file.
                    out_ptr.write(new_frame)
                else:
                    # show the frame online, mainly used for real-time speed test.
                    cv2.imshow('Frame', new_frame)
                    # Press Q on keyboard to exit
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
        finally:
            # Release explicitly so the input is closed and the output video
            # file is closed even on early exit or error, instead of relying
            # on object finalization.
            cap.release()
            if out_ptr is not None:
                out_ptr.release()
示例#10
0
文件: main.py 项目: matthewygf/automl
def main(_):
  """Train and/or evaluate EfficientDet with TF1 Estimators.

  FLAGS.mode selects the flow:
    train:          train up to the step count implied by num_epochs, then
                    evaluate once if FLAGS.eval_after_train is set.
    eval:           evaluate each new checkpoint in FLAGS.model_dir, stopping
                    once a checkpoint reaches the final training step or has
                    no step number in its name.
    train_and_eval: alternate one epoch of training with one evaluation,
                    resuming after the epoch of the latest checkpoint.
  Any other mode is logged as invalid.

  Returns:
    The non-zero exit code of a failed child process when
    FLAGS.run_epoch_in_child_process is set; otherwise None.

  Raises:
    RuntimeError: if a required file-pattern flag is missing, or if spatial
      partitioning is enabled and --num_cores_per_replica differs from the
      product of --input_partition_dims.
  """
  if FLAGS.strategy == 'tpu':
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    tpu_grpc_url = tpu_cluster_resolver.get_master()
    tf.Session.reset(tpu_grpc_url)
  else:
    tpu_cluster_resolver = None

  # Check data path
  if FLAGS.mode in ('train', 'train_and_eval'):
    if FLAGS.train_file_pattern is None:
      raise RuntimeError('Must specify --train_file_pattern for train.')
  if FLAGS.mode in ('eval', 'train_and_eval'):
    if FLAGS.val_file_pattern is None:
      raise RuntimeError('Must specify --val_file_pattern for eval.')

  # Parse and override hparams
  config = hparams_config.get_detection_config(FLAGS.model_name)
  config.override(FLAGS.hparams)
  if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
    config.num_epochs = FLAGS.num_epochs

  # Parse image size in case it is in string format.
  config.image_size = utils.parse_image_size(config.image_size)

  # The following is for spatial partitioning. `features` has one tensor while
  # `labels` has 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
  # partition is performed on `features` and all partitionable tensors of
  # `labels`, see the partition logic below.
  # In the TPUEstimator context, the meaning of `shard` and `replica` is the
  # same; following the API, here has mixed use of both.
  if FLAGS.use_spatial_partition:
    # Checks input_partition_dims agrees with num_cores_per_replica.
    if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
      raise RuntimeError('--num_cores_per_replica must be a product of array '
                         'elements in --input_partition_dims.')

    def _can_partition(spatial_dim):
      """Return True if every input partition dim evenly divides spatial_dim."""
      return bool(
          np.all(spatial_dim % np.array(FLAGS.input_partition_dims) == 0))

    labels_partition_dims = {
        'mean_num_positives': None,
        'source_ids': None,
        'groundtruth_data': None,
        'image_scales': None,
        'image_masks': None,
    }
    # The Input Partition Logic: We partition only the partition-able tensors.
    feat_sizes = utils.get_feat_sizes(
        config.get('image_size'), config.get('max_level'))
    for level in range(config.get('min_level'), config.get('max_level') + 1):
      spatial_dim = feat_sizes[level]
      if _can_partition(spatial_dim['height']) and _can_partition(
          spatial_dim['width']):
        level_dims = FLAGS.input_partition_dims
      else:
        level_dims = None
      labels_partition_dims['box_targets_%d' % level] = level_dims
      labels_partition_dims['cls_targets_%d' % level] = level_dims
    num_cores_per_replica = FLAGS.num_cores_per_replica
    input_partition_dims = [FLAGS.input_partition_dims, labels_partition_dims]
    num_shards = FLAGS.num_cores // num_cores_per_replica
  else:
    num_cores_per_replica = None
    input_partition_dims = None
    num_shards = FLAGS.num_cores

  params = dict(
      config.as_dict(),
      model_name=FLAGS.model_name,
      iterations_per_loop=FLAGS.iterations_per_loop,
      model_dir=FLAGS.model_dir,
      num_shards=num_shards,
      num_examples_per_epoch=FLAGS.num_examples_per_epoch,
      strategy=FLAGS.strategy,
      backbone_ckpt=FLAGS.backbone_ckpt,
      ckpt=FLAGS.ckpt,
      val_json_file=FLAGS.val_json_file,
      testdev_dir=FLAGS.testdev_dir,
      profile=FLAGS.profile,
      mode=FLAGS.mode)
  config_proto = tf.ConfigProto(
      allow_soft_placement=True, log_device_placement=False)
  if FLAGS.strategy != 'tpu' and FLAGS.use_xla:
    config_proto.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_1)
    config_proto.gpu_options.allow_growth = True

  model_dir = FLAGS.model_dir
  model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)
  max_instances_per_image = config.max_instances_per_image
  if FLAGS.eval_samples:
    # Ceiling division so a partial final batch is still evaluated.
    eval_steps = int((FLAGS.eval_samples + FLAGS.eval_batch_size - 1) //
                     FLAGS.eval_batch_size)
  else:
    eval_steps = None
  total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
  train_steps = total_examples // FLAGS.train_batch_size
  logging.info(params)

  if not tf.io.gfile.exists(model_dir):
    tf.io.gfile.makedirs(model_dir)

  config_file = os.path.join(model_dir, 'config.yaml')
  if not tf.io.gfile.exists(config_file):
    # Close the file deterministically so the config is fully written.
    with tf.io.gfile.GFile(config_file, 'w') as f:
      f.write(str(config))

  train_input_fn = dataloader.InputReader(
      FLAGS.train_file_pattern,
      is_training=True,
      use_fake_data=FLAGS.use_fake_data,
      max_instances_per_image=max_instances_per_image)
  eval_input_fn = dataloader.InputReader(
      FLAGS.val_file_pattern,
      is_training=False,
      use_fake_data=FLAGS.use_fake_data,
      max_instances_per_image=max_instances_per_image)

  if FLAGS.strategy == 'tpu':
    tpu_config = tf.estimator.tpu.TPUConfig(
        FLAGS.iterations_per_loop,
        num_cores_per_replica=num_cores_per_replica,
        input_partition_dims=input_partition_dims,
        per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
        .PER_HOST_V2)
    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=model_dir,
        log_step_count_steps=FLAGS.iterations_per_loop,
        session_config=config_proto,
        tpu_config=tpu_config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tf_random_seed=FLAGS.tf_random_seed,
    )
    # TPUEstimator can do both train and eval.
    train_est = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn_instance,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        config=run_config,
        params=params)
    eval_est = train_est
  else:
    strategy = None
    if FLAGS.strategy == 'gpus':
      strategy = tf.distribute.MirroredStrategy()
    run_config = tf.estimator.RunConfig(
        model_dir=model_dir,
        train_distribute=strategy,
        log_step_count_steps=FLAGS.iterations_per_loop,
        session_config=config_proto,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tf_random_seed=FLAGS.tf_random_seed,
    )

    def get_estimator(global_batch_size):
      """Build an Estimator whose per-replica batch size is derived from
      global_batch_size."""
      params['num_shards'] = getattr(strategy, 'num_replicas_in_sync', 1)
      params['batch_size'] = global_batch_size // params['num_shards']
      return tf.estimator.Estimator(
          model_fn=model_fn_instance, config=run_config, params=params)

    # train and eval need different estimator due to different batch size.
    train_est = get_estimator(FLAGS.train_batch_size)
    eval_est = get_estimator(FLAGS.eval_batch_size)

  # start train/eval flow.
  if FLAGS.mode == 'train':
    train_est.train(input_fn=train_input_fn, max_steps=train_steps)
    if FLAGS.eval_after_train:
      eval_est.evaluate(input_fn=eval_input_fn, steps=eval_steps)

  elif FLAGS.mode == 'eval':
    # Run evaluation when there's a new checkpoint
    for ckpt in tf.train.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.min_eval_interval,
        timeout=FLAGS.eval_timeout):

      logging.info('Starting to evaluate.')
      try:
        eval_results = eval_est.evaluate(eval_input_fn, steps=eval_steps)
        # Terminate eval job when final checkpoint is reached.
        try:
          current_step = int(os.path.basename(ckpt).split('-')[1])
        except IndexError:
          logging.info('%s has no global step info: stop!', ckpt)
          break

        utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
        if current_step >= train_steps:
          logging.info('Eval finished step %d/%d', current_step, train_steps)
          break

      except tf.errors.NotFoundError:
        # The checkpoint may already have been deleted by the time eval
        # finishes; skip such cases.
        logging.info('Checkpoint %s no longer exists, skipping.', ckpt)

  elif FLAGS.mode == 'train_and_eval':
    ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
    try:
      step = int(os.path.basename(ckpt).split('-')[1])
      current_epoch = (
          step * FLAGS.train_batch_size // FLAGS.num_examples_per_epoch)
      logging.info('found ckpt at step %d (epoch %d)', step, current_epoch)
    except (IndexError, TypeError):
      # TypeError covers ckpt being None (no checkpoint yet).
      logging.info('Folder %s has no ckpt with valid step.', FLAGS.model_dir)
      current_epoch = 0

    def run_train_and_eval(e):
      """Train through epoch e, evaluate, and archive the latest checkpoint."""
      print('\n   =====> Starting training, epoch: %d.' % e)
      train_est.train(
          input_fn=train_input_fn,
          max_steps=e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
      print('\n   =====> Starting evaluation, epoch: %d.' % e)
      eval_results = eval_est.evaluate(input_fn=eval_input_fn, steps=eval_steps)
      ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
      utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)

    epochs_per_cycle = 1  # higher number has less graph construction overhead.
    for e in range(current_epoch + 1, config.num_epochs + 1, epochs_per_cycle):
      if FLAGS.run_epoch_in_child_process:
        p = multiprocessing.Process(target=run_train_and_eval, args=(e,))
        p.start()
        p.join()
        if p.exitcode != 0:
          return p.exitcode
      else:
        tf.reset_default_graph()
        run_train_and_eval(e)

  else:
    logging.error('Invalid mode: %s', FLAGS.mode)
 def __init__(self, model_name):
     """Load the detection config for model_name with debug turned off.

     Args:
       model_name: detection config name, e.g. efficientdet-d0.
     """
     self.model_name = model_name
     self.config = hparams_config.get_detection_config(model_name)
     # image_size may be given in string form; normalize it, then force
     # debug off.
     self.config.image_size = utils.parse_image_size(self.config.image_size)
     self.config.update({'debug': False})
示例#12
0
def main(_):
    """Train and/or evaluate EfficientDet with Keras under a tf.distribute
    strategy.

    If FLAGS.mode contains 'train', fit the model (validating too when the
    mode also contains 'eval'). Otherwise, run continuous evaluation: each
    new checkpoint in FLAGS.model_dir is loaded, evaluated with
    COCOCallback, and archived. Evaluation stops once the checkpoint's epoch
    reaches config.num_epochs, or when no epoch number can be parsed from
    its name.

    Raises:
      ValueError: if the train/val file pattern flag needed for a dataset is
        empty.
    """
    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
        config.num_epochs = FLAGS.num_epochs

    # Parse image size in case it is in string format.
    config.image_size = utils.parse_image_size(config.image_size)

    if FLAGS.use_xla and FLAGS.strategy != 'tpu':
        tf.config.optimizer.set_jit(True)
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

    if FLAGS.debug:
        tf.config.run_functions_eagerly(True)
        tf.debugging.set_log_device_placement(True)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        tf.random.set_seed(FLAGS.tf_random_seed)
        logging.set_verbosity(logging.DEBUG)

    if FLAGS.strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
        logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
    elif FLAGS.strategy == 'gpus':
        ds_strategy = tf.distribute.MirroredStrategy()
        logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
    else:
        if tf.config.list_physical_devices('GPU'):
            ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
        else:
            ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')

    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.batch_size
    params = dict(profile=FLAGS.profile,
                  model_name=FLAGS.model_name,
                  steps_per_execution=FLAGS.steps_per_execution,
                  model_dir=FLAGS.model_dir,
                  steps_per_epoch=steps_per_epoch,
                  strategy=FLAGS.strategy,
                  batch_size=FLAGS.batch_size,
                  tf_random_seed=FLAGS.tf_random_seed,
                  debug=FLAGS.debug,
                  val_json_file=FLAGS.val_json_file,
                  eval_samples=FLAGS.eval_samples,
                  num_shards=ds_strategy.num_replicas_in_sync)
    config.override(params, True)
    # set mixed precision policy by keras api.
    precision = utils.get_precision(config.strategy, config.mixed_precision)
    policy = tf.keras.mixed_precision.Policy(precision)
    tf.keras.mixed_precision.set_global_policy(policy)

    def get_dataset(is_training, config):
        """Build the train or validation dataset from the matching flag."""
        file_pattern = (FLAGS.train_file_pattern
                        if is_training else FLAGS.val_file_pattern)
        if not file_pattern:
            raise ValueError('No matching files.')

        return dataloader.InputReader(
            file_pattern,
            is_training=is_training,
            use_fake_data=FLAGS.use_fake_data,
            max_instances_per_image=config.max_instances_per_image,
            debug=FLAGS.debug)(config.as_dict())

    with ds_strategy.scope():
        if config.model_optimizations:
            tfmot.set_config(config.model_optimizations.as_dict())
        if FLAGS.hub_module_url:
            model = train_lib.EfficientDetNetTrainHub(
                config=config, hub_module_url=FLAGS.hub_module_url)
        else:
            model = train_lib.EfficientDetNetTrain(config=config)
        model = setup_model(model, config)
        if FLAGS.pretrained_ckpt and not FLAGS.hub_module_url:
            ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt)
            util_keras.restore_ckpt(model, ckpt_path,
                                    config.moving_average_decay)
        init_experimental(config)
        if 'train' in FLAGS.mode:
            val_dataset = get_dataset(False,
                                      config) if 'eval' in FLAGS.mode else None
            model.fit(
                get_dataset(True, config),
                epochs=config.num_epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=train_lib.get_callbacks(config.as_dict(),
                                                  val_dataset),
                validation_data=val_dataset,
                validation_steps=(FLAGS.eval_samples // FLAGS.batch_size))
        else:
            # Continuous eval.
            for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                      min_interval_secs=180):
                logging.info('Starting to evaluate.')
                # Terminate eval job when final checkpoint is reached.
                try:
                    current_epoch = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    current_epoch = 0

                val_dataset = get_dataset(False, config)
                logging.info('start loading model.')
                # Load exactly the checkpoint yielded by the iterator, not
                # whatever is newest now, so results and archiving below are
                # attributed to the checkpoint actually evaluated.
                model.load_weights(ckpt)
                logging.info('finish loading model.')
                coco_eval = train_lib.COCOCallback(val_dataset, 1)
                coco_eval.set_model(model)
                eval_results = coco_eval.on_epoch_end(current_epoch)
                logging.info('eval results for %s: %s', ckpt, eval_results)

                try:
                    utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                except tf.errors.NotFoundError:
                    # The checkpoint may already have been deleted by the time
                    # eval finishes.
                    logging.info('Checkpoint %s no longer exists, skipping.',
                                 ckpt)

                if current_epoch >= config.num_epochs or not current_epoch:
                    logging.info('Eval epoch %d / %d', current_epoch,
                                 config.num_epochs)
                    break