Example #1
  def create_model(self) -> tf.keras.Model:
    """Creates the EfficientDet model."""
    return train_lib.EfficientDetNetTrainHub(
        config=self.config, hub_module_url=self.uri)
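The method above lives on a class that already holds a detection config (`self.config`) and a TF-Hub module URL (`self.uri`); calling it simply returns a trainable Keras model. As a rough sketch of the equivalent direct construction, assuming the repo-local `hparams_config` and `train_lib` modules that Example #2 also uses (the model name and hub URL below are placeholders, and import paths may differ):

import hparams_config
import train_lib

# Build a detection config by name, then wrap the TF-Hub backbone in the
# trainable EfficientDet model class.
config = hparams_config.get_detection_config('efficientdet-d0')  # example name
model = train_lib.EfficientDetNetTrainHub(
    config=config,
    hub_module_url='<TF-Hub module URL>')  # placeholder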
Example #2
def main(_):
    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
        config.num_epochs = FLAGS.num_epochs

    # Parse image size in case it is in string format.
    config.image_size = utils.parse_image_size(config.image_size)

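    # Enable XLA JIT compilation and per-GPU memory growth when not running on a TPU.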
    if FLAGS.use_xla and FLAGS.strategy != 'tpu':
        tf.config.optimizer.set_jit(True)
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

    if FLAGS.debug:
        tf.config.run_functions_eagerly(True)
        tf.debugging.set_log_device_placement(True)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        tf.random.set_seed(FLAGS.tf_random_seed)
        logging.set_verbosity(logging.DEBUG)

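    # Pick the distribution strategy: TPU, multi-GPU (MirroredStrategy), or a single GPU/CPU device.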
    if FLAGS.strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
        logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
    elif FLAGS.strategy == 'gpus':
        ds_strategy = tf.distribute.MirroredStrategy()
        logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
    else:
        if tf.config.list_physical_devices('GPU'):
            ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
        else:
            ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')

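    # Compute steps per epoch and fold the runtime flags back into the config.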
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.batch_size
    params = dict(profile=FLAGS.profile,
                  model_name=FLAGS.model_name,
                  steps_per_execution=FLAGS.steps_per_execution,
                  model_dir=FLAGS.model_dir,
                  steps_per_epoch=steps_per_epoch,
                  strategy=FLAGS.strategy,
                  batch_size=FLAGS.batch_size,
                  tf_random_seed=FLAGS.tf_random_seed,
                  debug=FLAGS.debug,
                  val_json_file=FLAGS.val_json_file,
                  eval_samples=FLAGS.eval_samples,
                  num_shards=ds_strategy.num_replicas_in_sync)
    config.override(params, True)
    # Set the mixed-precision policy via the Keras API.
    precision = utils.get_precision(config.strategy, config.mixed_precision)
    policy = tf.keras.mixed_precision.Policy(precision)
    tf.keras.mixed_precision.set_global_policy(policy)

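    # Input pipeline factory: picks the train or validation file pattern and builds the dataset.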
    def get_dataset(is_training, config):
        file_pattern = (FLAGS.train_file_pattern
                        if is_training else FLAGS.val_file_pattern)
        if not file_pattern:
            raise ValueError('No matching files.')

        return dataloader.InputReader(
            file_pattern,
            is_training=is_training,
            use_fake_data=FLAGS.use_fake_data,
            max_instances_per_image=config.max_instances_per_image,
            debug=FLAGS.debug)(config.as_dict())

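    # Build the model (optionally from a TF-Hub module), restore weights, then train or
    # continuously evaluate under the distribution strategy scope.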
    with ds_strategy.scope():
        if config.model_optimizations:
            tfmot.set_config(config.model_optimizations.as_dict())
        if FLAGS.hub_module_url:
            model = train_lib.EfficientDetNetTrainHub(
                config=config, hub_module_url=FLAGS.hub_module_url)
        else:
            model = train_lib.EfficientDetNetTrain(config=config)
        model = setup_model(model, config)
        if FLAGS.pretrained_ckpt and not FLAGS.hub_module_url:
            ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt)
            util_keras.restore_ckpt(model, ckpt_path,
                                    config.moving_average_decay)
        init_experimental(config)
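        # Training path: fit on the training data, validating on the eval split when 'eval' is also in --mode.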
        if 'train' in FLAGS.mode:
            val_dataset = (get_dataset(False, config)
                           if 'eval' in FLAGS.mode else None)
            model.fit(
                get_dataset(True, config),
                epochs=config.num_epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=train_lib.get_callbacks(config.as_dict(),
                                                  val_dataset),
                validation_data=val_dataset,
                validation_steps=(FLAGS.eval_samples // FLAGS.batch_size))
        else:
            # Continuous eval.
            for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                      min_interval_secs=180):
                logging.info('Starting to evaluate.')
                # Terminate eval job when final checkpoint is reached.
                try:
                    current_epoch = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    current_epoch = 0

                val_dataset = get_dataset(False, config)
                logging.info('start loading model.')
                model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir))
                logging.info('finish loading model.')
                coco_eval = train_lib.COCOCallback(val_dataset, 1)
                coco_eval.set_model(model)
                eval_results = coco_eval.on_epoch_end(current_epoch)
                logging.info('eval results for %s: %s', ckpt, eval_results)

                try:
                    utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                except tf.errors.NotFoundError:
                    # Checkpoint might have been deleted by the time eval finished.
                    logging.info('Checkpoint %s no longer exists, skipping.',
                                 ckpt)

                if current_epoch >= config.num_epochs or not current_epoch:
                    logging.info('Eval epoch %d / %d', current_epoch,
                                 config.num_epochs)
                    break
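
Both examples assume the repo-local modules imported elsewhere in the script (`train_lib`, `hparams_config`, `utils`, `dataloader`, `util_keras`, `tfmot`), plus the flag definitions, `setup_model()`, and `init_experimental()` defined in the same file. Judging purely from the flags referenced in `main()`, a single-machine GPU training run might be launched along these lines (a sketch only; the script name and all values are placeholders, and including 'eval' in --mode enables in-loop validation):

python train.py \
    --mode=train \
    --model_name=efficientdet-d0 \
    --strategy=gpus \
    --model_dir=/tmp/efficientdet-d0 \
    --train_file_pattern=/data/train-*.tfrecord \
    --val_file_pattern=/data/val-*.tfrecord \
    --val_json_file=/data/instances_val.json \
    --batch_size=64 \
    --num_examples_per_epoch=100000 \
    --num_epochs=300 \
    --eval_samples=5000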