def train(self,
          model: tf.keras.Model,
          train_dataset: tf.data.Dataset,
          steps_per_epoch: int,
          val_dataset: Optional[tf.data.Dataset],
          validation_steps: int,
          epochs: Optional[int] = None,
          batch_size: Optional[int] = None,
          val_json_file: Optional[str] = None) -> tf.keras.Model:
    """Run EfficientDet training.

    ``epochs`` and ``batch_size`` fall back to ``self.config.num_epochs`` and
    ``self.config.batch_size`` whenever they are falsy (None or 0). Before
    fitting, ``self.config`` is updated in place with this run's
    steps_per_epoch, eval_samples (batch_size * validation_steps),
    val_json_file and batch_size.

    Args:
      model: Keras model to set up and fit.
      train_dataset: Dataset passed to ``model.fit`` as training data.
      steps_per_epoch: Training steps per epoch.
      val_dataset: Optional validation dataset; also handed to the callbacks.
      validation_steps: Number of validation steps per epoch.
      epochs: Number of epochs; falsy means use the config value.
      batch_size: Batch size; falsy means use the config value.
      val_json_file: Validation annotation file recorded in the config.

    Returns:
      The same ``model`` object, after ``model.fit`` has run.
    """
    cfg = self.config
    epochs = epochs or cfg.num_epochs
    batch_size = batch_size or cfg.batch_size

    run_overrides = {
        'steps_per_epoch': steps_per_epoch,
        'eval_samples': batch_size * validation_steps,
        'val_json_file': val_json_file,
        'batch_size': batch_size,
    }
    cfg.update(run_overrides)

    # Inside a method body the bare name `train` resolves to a module-level
    # name (presumably the imported EfficientDet keras `train` module -- TODO
    # confirm), not to this method.
    train.setup_model(model, cfg)
    train.init_experimental(cfg)

    fit_callbacks = train_lib.get_callbacks(cfg.as_dict(), val_dataset)
    model.fit(
        train_dataset,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=fit_callbacks,
        validation_data=val_dataset,
        validation_steps=validation_steps)
    return model
def _create_distribution_strategy():
    """Build the tf.distribute strategy selected by ``FLAGS.strategy``.

    Returns:
      A TPUStrategy for 'tpu', a MirroredStrategy for 'gpus', and otherwise a
      OneDeviceStrategy on 'device:GPU:0' when any GPU is visible, else on
      'device:CPU:0'.
    """
    if FLAGS.strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
        logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
        return ds_strategy
    if FLAGS.strategy == 'gpus':
        ds_strategy = tf.distribute.MirroredStrategy()
        logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
        return ds_strategy
    if tf.config.list_physical_devices('GPU'):
        return tf.distribute.OneDeviceStrategy('device:GPU:0')
    return tf.distribute.OneDeviceStrategy('device:CPU:0')


def _checkpoint_epoch(ckpt):
    """Return the epoch parsed from a checkpoint path, or 0 if there is none.

    The epoch is read as the integer between the first and second '-' of the
    path's basename (e.g. '.../ckpt-12' -> 12). A basename with no '-' or a
    non-integer component (e.g. 'ckpt-final') yields 0 instead of raising.
    """
    try:
        return int(os.path.basename(ckpt).split('-')[1])
    except (IndexError, ValueError):
        return 0


def _continuous_eval(model, config, get_dataset):
    """Evaluate every new checkpoint in ``FLAGS.model_dir`` as it appears.

    Polls the directory at most every 180 seconds. Each checkpoint is loaded
    into ``model`` and scored with ``train_lib.COCOCallback``. It is then
    archived via ``utils.archive_ckpt``, keyed on the 'AP' metric. Stops after
    a checkpoint whose epoch is >= ``config.num_epochs`` or whose epoch cannot
    be parsed (epoch 0).
    """
    for ckpt in tf.train.checkpoints_iterator(
            FLAGS.model_dir, min_interval_secs=180):
        logging.info('Starting to evaluate.')
        current_epoch = _checkpoint_epoch(ckpt)

        val_dataset = get_dataset(False, config)
        logging.info('start loading model.')
        # Load exactly the checkpoint being evaluated. Re-querying the latest
        # checkpoint could pick up a newer one written meanwhile, and its
        # results would be reported/archived under the wrong ckpt and epoch.
        model.load_weights(ckpt)
        logging.info('finish loading model.')
        coco_eval = train_lib.COCOCallback(val_dataset, 1)
        coco_eval.set_model(model)
        eval_results = coco_eval.on_epoch_end(current_epoch)
        logging.info('eval results for %s: %s', ckpt, eval_results)

        try:
            utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
        except tf.errors.NotFoundError:
            # The checkpoint may already have been deleted by the time eval
            # finished.
            logging.info('Checkpoint %s no longer exists, skipping.', ckpt)

        # Terminate eval job when final checkpoint is reached.
        if current_epoch >= config.num_epochs or not current_epoch:
            logging.info('Eval epoch %d / %d', current_epoch, config.num_epochs)
            break


def main(_):
    """Configure EfficientDet from flags, then train or continuously evaluate.

    If ``FLAGS.mode`` contains 'train', the model is fitted with
    ``model.fit``. Validation runs only when the mode also contains 'eval'.
    Otherwise every new checkpoint in ``FLAGS.model_dir`` is evaluated until
    the final epoch's checkpoint is reached.
    """
    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
        config.num_epochs = FLAGS.num_epochs

    # Parse image size in case it is in string format.
    config.image_size = utils.parse_image_size(config.image_size)

    if FLAGS.use_xla and FLAGS.strategy != 'tpu':
        tf.config.optimizer.set_jit(True)
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

    if FLAGS.debug:
        tf.config.run_functions_eagerly(True)
        tf.debugging.set_log_device_placement(True)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        tf.random.set_seed(FLAGS.tf_random_seed)
        logging.set_verbosity(logging.DEBUG)

    ds_strategy = _create_distribution_strategy()

    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.batch_size
    params = dict(
        profile=FLAGS.profile,
        model_name=FLAGS.model_name,
        steps_per_execution=FLAGS.steps_per_execution,
        model_dir=FLAGS.model_dir,
        steps_per_epoch=steps_per_epoch,
        strategy=FLAGS.strategy,
        batch_size=FLAGS.batch_size,
        tf_random_seed=FLAGS.tf_random_seed,
        debug=FLAGS.debug,
        val_json_file=FLAGS.val_json_file,
        eval_samples=FLAGS.eval_samples,
        num_shards=ds_strategy.num_replicas_in_sync)
    config.override(params, True)

    # set mixed precision policy by keras api.
    precision = utils.get_precision(config.strategy, config.mixed_precision)
    policy = tf.keras.mixed_precision.Policy(precision)
    tf.keras.mixed_precision.set_global_policy(policy)

    def get_dataset(is_training, config):
        """Build the train or validation input pipeline from the file flags."""
        flag_name = 'train_file_pattern' if is_training else 'val_file_pattern'
        file_pattern = (FLAGS.train_file_pattern
                        if is_training else FLAGS.val_file_pattern)
        if not file_pattern:
            # The flag itself is empty; report which one rather than claiming
            # that a (non-existent) pattern matched no files.
            raise ValueError('No file pattern given: set --%s.' % flag_name)
        return dataloader.InputReader(
            file_pattern,
            is_training=is_training,
            use_fake_data=FLAGS.use_fake_data,
            max_instances_per_image=config.max_instances_per_image,
            debug=FLAGS.debug)(config.as_dict())

    with ds_strategy.scope():
        if config.model_optimizations:
            tfmot.set_config(config.model_optimizations.as_dict())
        if FLAGS.hub_module_url:
            model = train_lib.EfficientDetNetTrainHub(
                config=config, hub_module_url=FLAGS.hub_module_url)
        else:
            model = train_lib.EfficientDetNetTrain(config=config)
        model = setup_model(model, config)
        if FLAGS.pretrained_ckpt and not FLAGS.hub_module_url:
            ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt)
            util_keras.restore_ckpt(model, ckpt_path,
                                    config.moving_average_decay)
        init_experimental(config)
        if 'train' in FLAGS.mode:
            val_dataset = (get_dataset(False, config)
                           if 'eval' in FLAGS.mode else None)
            model.fit(
                get_dataset(True, config),
                epochs=config.num_epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=train_lib.get_callbacks(config.as_dict(),
                                                  val_dataset),
                validation_data=val_dataset,
                validation_steps=(FLAGS.eval_samples // FLAGS.batch_size))
        else:
            # Continuous eval.
            _continuous_eval(model, config, get_dataset)