import functools

from absl import flags
from absl import logging
import tensorflow as tf

# Import paths below assume the TensorFlow Model Garden layout
# (official/vision/detection); adjust them to match your checkout.
from official.modeling.training import distributed_executor as executor
from official.utils.misc import distribution_utils
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import (
    DetectionDistributedExecutor)
from official.vision.detection.modeling import factory as model_factory

FLAGS = flags.FLAGS


def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
  """Runs the object detection model on the distribution strategy defined by the user."""
  if params.architecture.use_bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  model_builder = model_factory.model_generator(params)

  if prebuilt_strategy is not None:
    strategy = prebuilt_strategy
  else:
    strategy_config = params.strategy_config
    distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                         strategy_config.task_index)
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=params.strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)

  # Assume 8 replicas per host and ceil-divide to estimate the host count;
  # with two or more hosts the input pipeline must be sharded per replica.
  num_workers = (strategy.num_replicas_in_sync + 7) // 8
  is_multi_host = num_workers >= 2

  if mode == 'train':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.TRAIN)

    logging.info(
        'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if is_multi_host:
      # Each host reads its own shard, so scale the batch size down to the
      # per-replica value.
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size //
          strategy.num_replicas_in_sync)

    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)
  elif mode == 'eval' or mode == 'eval_once':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)

    logging.info(
        'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    if is_multi_host:
      eval_input_fn = functools.partial(
          eval_input_fn,
          batch_size=params.eval.batch_size //
          strategy.num_replicas_in_sync)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if mode == 'eval':
      # Continuously polls model_dir for new checkpoints until total_steps
      # is reached or eval_timeout expires.
      results = dist_executor.evaluate_from_model_dir(
          model_dir=params.model_dir,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          eval_timeout=params.eval.eval_timeout,
          min_eval_interval=params.eval.min_eval_interval,
          total_steps=params.train.total_steps)
    else:
      # Run evaluation once for a single checkpoint.
      if not checkpoint_path:
        raise ValueError('checkpoint_path cannot be empty.')
      if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
      summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
      results, _ = dist_executor.evaluate_checkpoint(
          checkpoint_path=checkpoint_path,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          summary_writer=summary_writer)
    for k, v in results.items():
      logging.info('Final eval metric %s: %f', k, v)
    return results
  else:
    raise ValueError('Mode not found: %s.' % mode)
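

# A minimal usage sketch for the parameterized entry point above: evaluate
# the latest checkpoint in a directory exactly once. The config factory call
# and the 'retinanet' model name are assumptions based on Model Garden
# conventions; `eval_input_fn` is a caller-supplied dataset function and
# `model_dir` is a hypothetical path.
def _example_eval_once(eval_input_fn, model_dir='/tmp/retinanet'):
  from official.vision.detection.configs import factory as config_factory

  params = config_factory.config_generator('retinanet')
  params.override({'model_dir': model_dir}, is_strict=False)
  return run_executor(
      params,
      mode='eval_once',
      # A directory is accepted here: the latest checkpoint inside is used.
      checkpoint_path=model_dir,
      eval_input_fn=eval_input_fn)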
# Command-line flags consumed by the FLAGS-driven variant below; the names
# and defaults mirror the Model Garden's detection main.py.
flags.DEFINE_string(
    'mode', default='train',
    help='Mode to run: `train`, `eval` or `eval_once`.')
flags.DEFINE_string(
    'checkpoint_path', default=None,
    help='The checkpoint path to evaluate. Only used in eval_once mode.')


def run_executor_from_flags(params,
                            train_input_fn=None,
                            eval_input_fn=None,
                            callbacks=None):
  """Runs the RetinaNet model on the distribution strategy defined by the user.

  Legacy variant of `run_executor` above, kept under a distinct name so both
  can coexist in one module: it reads the mode and checkpoint path from absl
  flags and builds the strategy through `executor.ExecutorBuilder`.
  """
  model_builder = model_factory.model_generator(params)

  if FLAGS.mode == 'train':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.TRAIN)

    builder = executor.ExecutorBuilder(
        strategy_type=params.strategy_type,
        strategy_config=params.strategy_config)
    # Assume 8 replicas per host; two or more hosts means the input pipeline
    # must be sharded per replica.
    num_workers = (builder.strategy.num_replicas_in_sync + 7) // 8
    is_multi_host = num_workers >= 2

    logging.info(
        'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
        builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)

    if is_multi_host:
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size //
          builder.strategy.num_replicas_in_sync)

    dist_executor = builder.build_executor(
        class_ctor=DetectionDistributedExecutor,
        params=params,
        is_multi_host=is_multi_host,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)
  elif FLAGS.mode == 'eval' or FLAGS.mode == 'eval_once':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)

    builder = executor.ExecutorBuilder(
        strategy_type=params.strategy_type,
        strategy_config=params.strategy_config)
    num_workers = (builder.strategy.num_replicas_in_sync + 7) // 8
    is_multi_host = num_workers >= 2

    if is_multi_host:
      eval_input_fn = functools.partial(
          eval_input_fn,
          batch_size=params.eval.batch_size //
          builder.strategy.num_replicas_in_sync)

    logging.info(
        'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
        builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)

    dist_executor = builder.build_executor(
        class_ctor=DetectionDistributedExecutor,
        params=params,
        is_multi_host=is_multi_host,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if FLAGS.mode == 'eval':
      results = dist_executor.evaluate_from_model_dir(
          model_dir=params.model_dir,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          eval_timeout=params.eval.eval_timeout,
          min_eval_interval=params.eval.min_eval_interval,
          total_steps=params.train.total_steps)
    else:
      # Run evaluation once for a single checkpoint.
      if not FLAGS.checkpoint_path:
        raise ValueError('FLAGS.checkpoint_path cannot be empty.')
      checkpoint_path = FLAGS.checkpoint_path
      if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
      summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
      results, _ = dist_executor.evaluate_checkpoint(
          checkpoint_path=checkpoint_path,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          summary_writer=summary_writer)
    for k, v in results.items():
      logging.info('Final eval metric %s: %f', k, v)
    return results
  else:
    raise ValueError('Mode not found: %s.' % FLAGS.mode)
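

# A minimal launch sketch for the FLAGS-driven variant above, showing how it
# would be wired into an absl entry point. This is an assumption-laden
# example, not the project's actual main(): the config factory call, the
# 'retinanet' model name, and the stub input function stand in for the Model
# Garden's real config and input_reader plumbing.
def main(argv):
  del argv  # Unused.
  from official.vision.detection.configs import factory as config_factory

  params = config_factory.config_generator('retinanet')  # assumed factory

  def _stub_input_fn(batch_size=None):
    # Placeholder: real code builds a tf.data.Dataset from TFRecords via the
    # project's input_reader before calling the executor.
    raise NotImplementedError('Wire up a real dataset function here.')

  run_executor_from_flags(
      params,
      train_input_fn=_stub_input_fn,
      eval_input_fn=_stub_input_fn)


if __name__ == '__main__':
  from absl import app
  app.run(main)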