コード例 #1
0
def run_executor(params,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None):
    """Runs Retinanet model on distribution strategy defined by the user.

    The action is selected by `FLAGS.mode`:
      * 'train': trains via `dist_executor.train`.
      * 'eval': evaluates checkpoints in `params.model_dir` via
        `dist_executor.evaluate_from_model_dir`.
      * 'eval_once': evaluates the single checkpoint named by
        `FLAGS.checkpoint_path`; if that path is a directory, its latest
        checkpoint is used.

    Args:
      params: experiment config. Fields read here include `architecture`,
        `strategy_type`, `strategy_config`, `model_dir`, `train` and `eval`.
      train_input_fn: input function for 'train' mode. On multi-host setups
        it is wrapped with `functools.partial` to bind a per-replica
        `batch_size` keyword argument.
      eval_input_fn: input function for the eval modes; receives the same
        multi-host `batch_size` binding as `train_input_fn`.
      callbacks: optional callbacks forwarded to `dist_executor.train` as
        `custom_callbacks`.

    Returns:
      In 'train' mode, whatever `dist_executor.train` returns. In the eval
      modes, the evaluation results (metric name -> value), which are also
      logged.

    Raises:
      ValueError: if `FLAGS.mode` is not recognised, if 'eval_once' is run
        with an empty `FLAGS.checkpoint_path`, or if `FLAGS.checkpoint_path`
        is a directory that contains no checkpoint.
    """

    if params.architecture.use_bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    model_builder = model_factory.model_generator(params)

    def _build_executor(model_mode, phase, input_fn, phase_params):
        """Creates the distributed executor for one phase (train or eval).

        Args:
          model_mode: `ModeKeys` value passed to `model_builder.build_model`.
          phase: label ('Train' / 'Eval') used as the log-message prefix.
          input_fn: the phase's input function.
          phase_params: `params.train` or `params.eval`; its `batch_size` is
            read only on multi-host setups.

        Returns:
          A `(dist_executor, input_fn)` tuple. When running multi-host,
          `input_fn` has `batch_size` bound to the global batch size divided
          by the number of replicas.
        """

        def _model_fn(params):
            return model_builder.build_model(params, mode=model_mode)

        builder = executor.ExecutorBuilder(
            strategy_type=params.strategy_type,
            strategy_config=params.strategy_config)
        num_replicas = builder.strategy.num_replicas_in_sync
        # Integer ceiling division by 8.
        # NOTE(review): presumably 8 replicas per host; confirm for the
        # target hardware.
        num_workers = int(num_replicas + 7) // 8
        is_multi_host = num_workers >= 2
        logging.info(
            '%s num_replicas_in_sync %d num_workers %d is_multi_host %s',
            phase, num_replicas, num_workers, is_multi_host)
        if is_multi_host:
            # Each input pipeline is given the per-replica share of the
            # configured (global) batch size.
            input_fn = functools.partial(
                input_fn, batch_size=phase_params.batch_size // num_replicas)

        dist_executor = builder.build_executor(
            class_ctor=DetectionDistributedExecutor,
            params=params,
            is_multi_host=is_multi_host,
            model_fn=_model_fn,
            loss_fn=model_builder.build_loss_fn,
            predict_post_process_fn=model_builder.post_processing,
            trainable_variables_filter=(
                model_builder.make_filter_trainable_variables_fn()))
        return dist_executor, input_fn

    if FLAGS.mode == 'train':
        dist_executor, train_input_fn = _build_executor(
            ModeKeys.TRAIN, 'Train', train_input_fn, params.train)
        return dist_executor.train(
            train_input_fn=train_input_fn,
            model_dir=params.model_dir,
            iterations_per_loop=params.train.iterations_per_loop,
            total_steps=params.train.total_steps,
            init_checkpoint=model_builder.make_restore_checkpoint_fn(),
            custom_callbacks=callbacks,
            save_config=True)

    if FLAGS.mode not in ('eval', 'eval_once'):
        raise ValueError('Mode not found: %s.' % FLAGS.mode)

    dist_executor, eval_input_fn = _build_executor(
        ModeKeys.PREDICT_WITH_GT, 'Eval', eval_input_fn, params.eval)

    if FLAGS.mode == 'eval':
        results = dist_executor.evaluate_from_model_dir(
            model_dir=params.model_dir,
            eval_input_fn=eval_input_fn,
            eval_metric_fn=model_builder.eval_metrics,
            eval_timeout=params.eval.eval_timeout,
            min_eval_interval=params.eval.min_eval_interval,
            total_steps=params.train.total_steps)
    else:
        # Run evaluation once for a single checkpoint.
        if not FLAGS.checkpoint_path:
            raise ValueError('FLAGS.checkpoint_path cannot be empty.')
        checkpoint_path = FLAGS.checkpoint_path
        if tf.io.gfile.isdir(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
            if checkpoint_path is None:
                # latest_checkpoint returns None for a directory without
                # checkpoints; fail loudly rather than passing None on to
                # evaluate_checkpoint.
                raise ValueError('No checkpoint found in directory: %s' %
                                 FLAGS.checkpoint_path)
        summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
        results, _ = dist_executor.evaluate_checkpoint(
            checkpoint_path=checkpoint_path,
            eval_input_fn=eval_input_fn,
            eval_metric_fn=model_builder.eval_metrics,
            summary_writer=summary_writer)

    for k, v in results.items():
        logging.info('Final eval metric %s: %f', k, v)
    return results
コード例 #2
0
def run_executor(params, train_input_fn=None, eval_input_fn=None):
    """Runs Retinanet model on distribution strategy defined by the user.

    The action is selected by `FLAGS.mode`:
      * 'train': trains via `dist_executor.train`.
      * 'eval': evaluates checkpoints in `params.model_dir` via
        `dist_executor.evaluate_from_model_dir` and logs the final metrics.

    Args:
      params: experiment config. Fields read here include `strategy_type`,
        `strategy_config`, `model_dir`, `train` and `eval`.
      train_input_fn: input function for 'train' mode. On multi-host setups
        it is wrapped with `functools.partial` to bind a per-replica
        `batch_size` keyword argument.
      eval_input_fn: input function for 'eval' mode (passed through as-is).

    Returns:
      In 'train' mode, whatever `dist_executor.train` returns. In 'eval'
      mode, the evaluation results (metric name -> value).

    Raises:
      ValueError: if `FLAGS.mode` is neither 'train' nor 'eval'.
    """

    model_builder = model_factory.model_generator(params)

    if FLAGS.mode == 'train':

        def _model_fn(params):
            return model_builder.build_model(params, mode=ModeKeys.TRAIN)

        builder = executor.ExecutorBuilder(
            strategy_type=params.strategy_type,
            strategy_config=params.strategy_config)
        num_replicas = builder.strategy.num_replicas_in_sync
        # Integer ceiling division by 8 (presumably 8 replicas per host --
        # NOTE(review): confirm). Floor division `//` is required: true
        # division `/` returns a float, so e.g. 8 replicas would give
        # 15 / 8 = 1.875 > 1 and wrongly mark a single host as multi-host.
        num_workers = (num_replicas + 7) // 8
        is_multi_host = num_workers > 1
        if is_multi_host:
            # Each input pipeline is given the per-replica share of the
            # configured (global) batch size.
            train_input_fn = functools.partial(
                train_input_fn,
                batch_size=params.train.batch_size // num_replicas)

        dist_executor = builder.build_executor(
            class_ctor=DetectionDistributedExecutor,
            params=params,
            is_multi_host=is_multi_host,
            model_fn=_model_fn,
            loss_fn=model_builder.build_loss_fn,
            predict_post_process_fn=model_builder.post_processing,
            trainable_variables_filter=(
                model_builder.make_filter_trainable_variables_fn()))

        return dist_executor.train(
            train_input_fn=train_input_fn,
            model_dir=params.model_dir,
            iterations_per_loop=params.train.iterations_per_loop,
            total_steps=params.train.total_steps,
            init_checkpoint=model_builder.make_restore_checkpoint_fn(),
            save_config=True)
    elif FLAGS.mode == 'eval':

        def _model_fn(params):
            return model_builder.build_model(params,
                                             mode=ModeKeys.PREDICT_WITH_GT)

        # NOTE(review): unlike the train branch, eval neither computes
        # is_multi_host nor binds a per-replica batch_size on eval_input_fn;
        # confirm this is intended for multi-host evaluation.
        builder = executor.ExecutorBuilder(
            strategy_type=params.strategy_type,
            strategy_config=params.strategy_config)
        dist_executor = builder.build_executor(
            class_ctor=DetectionDistributedExecutor,
            params=params,
            model_fn=_model_fn,
            loss_fn=model_builder.build_loss_fn,
            predict_post_process_fn=model_builder.post_processing,
            trainable_variables_filter=(
                model_builder.make_filter_trainable_variables_fn()))

        results = dist_executor.evaluate_from_model_dir(
            model_dir=params.model_dir,
            eval_input_fn=eval_input_fn,
            eval_metric_fn=model_builder.eval_metrics,
            eval_timeout=params.eval.eval_timeout,
            min_eval_interval=params.eval.min_eval_interval,
            total_steps=params.train.total_steps)
        for k, v in results.items():
            logging.info('Final eval metric %s: %f', k, v)
        return results
    else:
        raise ValueError('Mode not found: %s.' % FLAGS.mode)