def run_executor(params,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None):
  """Runs the RetinaNet model on the distribution strategy defined by the user."""

  if params.architecture.use_bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  model_builder = model_factory.model_generator(params)

  if FLAGS.mode == 'train':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.TRAIN)

    builder = executor.ExecutorBuilder(
        strategy_type=params.strategy_type,
        strategy_config=params.strategy_config)
    # Estimate the number of hosts, assuming at most 8 replicas per host.
    num_workers = (builder.strategy.num_replicas_in_sync + 7) // 8
    is_multi_host = (num_workers >= 2)
    logging.info(
        'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
        builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)

    if is_multi_host:
      # In multi-host mode each host builds its own dataset, so the input fn
      # receives the per-replica batch size.
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size //
          builder.strategy.num_replicas_in_sync)

    dist_executor = builder.build_executor(
        class_ctor=DetectionDistributedExecutor,
        params=params,
        is_multi_host=is_multi_host,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)

  elif FLAGS.mode in ('eval', 'eval_once'):

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)

    builder = executor.ExecutorBuilder(
        strategy_type=params.strategy_type,
        strategy_config=params.strategy_config)
    num_workers = (builder.strategy.num_replicas_in_sync + 7) // 8
    is_multi_host = (num_workers >= 2)
    if is_multi_host:
      eval_input_fn = functools.partial(
          eval_input_fn,
          batch_size=params.eval.batch_size //
          builder.strategy.num_replicas_in_sync)
    logging.info(
        'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
        builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)

    dist_executor = builder.build_executor(
        class_ctor=DetectionDistributedExecutor,
        params=params,
        is_multi_host=is_multi_host,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if FLAGS.mode == 'eval':
      # Continuously evaluate new checkpoints as they appear in model_dir.
      results = dist_executor.evaluate_from_model_dir(
          model_dir=params.model_dir,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          eval_timeout=params.eval.eval_timeout,
          min_eval_interval=params.eval.min_eval_interval,
          total_steps=params.train.total_steps)
    else:
      # Run evaluation once for a single checkpoint.
      if not FLAGS.checkpoint_path:
        raise ValueError('FLAGS.checkpoint_path cannot be empty.')
      checkpoint_path = FLAGS.checkpoint_path
      if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
      summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
      results, _ = dist_executor.evaluate_checkpoint(
          checkpoint_path=checkpoint_path,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          summary_writer=summary_writer)

    for k, v in results.items():
      logging.info('Final eval metric %s: %f', k, v)
    return results
  else:
    raise ValueError('Mode not found: %s.' % FLAGS.mode)
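

# A minimal usage sketch: one way run_executor could be driven from an absl
# main() entrypoint. `config_factory.config_generator`, `input_reader.InputFn`,
# and their argument names are assumptions modeled on the TF Model Garden
# detection pipeline, not a confirmed API of this module.
def main(argv):
  del argv  # Unused.

  # Build the experiment config for the requested model family (assumed helper).
  params = config_factory.config_generator(FLAGS.model)

  train_input_fn = None
  eval_input_fn = None
  if FLAGS.mode == 'train':
    train_input_fn = input_reader.InputFn(
        file_pattern=params.train.train_file_pattern,
        params=params,
        mode=ModeKeys.TRAIN,
        batch_size=params.train.batch_size)
  elif FLAGS.mode in ('eval', 'eval_once'):
    eval_input_fn = input_reader.InputFn(
        file_pattern=params.eval.eval_file_pattern,
        params=params,
        mode=ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size)

  run_executor(
      params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)


if __name__ == '__main__':
  app.run(main)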