def main(unused_argv):
    """Validates flags, builds the run configuration and launches training.

    Checks that --model, --device and --input_layout hold supported values,
    sets or clears the TF_ENABLE_WINOGRAD_NONFUSED environment variable
    according to --winograd_nonfused, and initialises the module-level
    ``_cfg`` plus the learning-rate schedule. When running on more than one
    GPU shard, the work is delegated to ``run_on_gpu``. Otherwise a
    ``TPUEstimator`` is built and handed to ``model_conductor.conduct``
    together with the train/eval input functions.

    Args:
      unused_argv: Positional command-line arguments; ignored.

    Raises:
      RuntimeError: If --model, --device or --input_layout is not one of the
        supported values.
    """
    global _cfg

    if layers_resnet.get_model(FLAGS.model) is None:
        known_models = ', '.join(layers_resnet.get_available_models())
        raise RuntimeError('--model must be one of [' + known_models + ']')

    # Each flag is read lazily, in the same order as it is validated.
    for flag_name, allowed, message in (
            ('device', ('CPU', 'GPU', 'TPU'),
             '--device must be one of [CPU, GPU, TPU]'),
            ('input_layout', ('NCHW', 'NHWC'),
             '--input_layout must be one of [NCHW, NHWC]')):
        if getattr(FLAGS, flag_name) not in allowed:
            raise RuntimeError(message)

    winograd_var = 'TF_ENABLE_WINOGRAD_NONFUSED'
    if not FLAGS.winograd_nonfused:
        # Clear any value inherited from the parent environment.
        os.environ.pop(winograd_var, None)
    else:
        os.environ[winograd_var] = '1'

    _cfg = ResnetConfig()
    setup_learning_rate_schedule()

    on_gpu = FLAGS.device == 'GPU'

    sess_conf = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement)
    if on_gpu:
        sess_conf.gpu_options.allow_growth = True

    tpu_settings = tpu_config.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop,
        num_shards=FLAGS.num_shards,
        per_host_input_for_training=FLAGS.per_host_input_pipeline)
    run_config = tpu_config.RunConfig(
        # A value of 0 is mapped to None.
        save_checkpoints_secs=FLAGS.save_checkpoints_secs or None,
        save_summary_steps=FLAGS.save_summary_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_settings,
        session_config=sess_conf)

    if on_gpu and FLAGS.num_shards > 1:
        # Multi-GPU runs take a separate path and never reach TPUEstimator.
        run_on_gpu(run_config)
        return

    estimator = tpu_estimator.TPUEstimator(
        model_fn=get_model_fn(),
        use_tpu=(FLAGS.device == 'TPU'),
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size,
        batch_axis=(get_image_batch_axis(), 0))

    train_input_fn = get_input_pipeline_fn()

    def eval_input(params=None):
        # Reuses the training pipeline, passing an explicit eval batch size.
        return train_input_fn(params=params, eval_batch_size=FLAGS.batch_size)

    model_conductor.conduct(
        estimator,
        train_input_fn,
        eval_input,
        get_train_steps(),
        FLAGS.epochs_per_train * _cfg.batches_per_epoch,
        FLAGS.eval_steps,
        train_hooks=get_train_hooks(),
        target_accuracy=FLAGS.target_accuracy)
# Example #2
def multigpu_run(config,
                 train_inputfn,
                 eval_inputfn,
                 modelfn,
                 num_gpus,
                 batch_size,
                 shard_axis,
                 weight_decay,
                 momentum,
                 learning_rate,
                 train_steps,
                 eval_steps,
                 steps_per_train,
                 target_accuracy=None,
                 train_hooks=None):
    """Trains and evaluates a model whose computation is sharded across GPUs.

    Args:
      config: RunConfig passed to the ``tf.estimator.Estimator`` that runs the
        GPU model.
      train_inputfn: Zero-argument input function for training. It returns a
        tuple of features and the one-hot labels matching them.
      eval_inputfn: Zero-argument input function for evaluation, with the same
        return contract as ``train_inputfn``.
      modelfn: Core model function that builds the model computation graph.
      num_gpus: Number of GPU devices to shard the computation onto.
      batch_size: Global batch size, before splitting across GPUs.
      shard_axis: Tuple giving the axis along which the inputs and the labels,
        respectively, are split across GPU devices.
      weight_decay: Weight regularization strength (float).
      momentum: Momentum for MomentumOptimizer.
      learning_rate: Function object that builds the learning-rate graph.
      train_steps: Total number of training steps to execute.
      eval_steps: Number of steps to execute for each evaluation.
      steps_per_train: Number of training steps to run before each round of
        evaluation.
      target_accuracy: Optional accuracy target at which training stops.
      train_hooks: Optional extra hooks for the training operation.
    """
    # Estimator inspects model_fn's argument names, so these must stay
    # exactly `features`, `labels`, `mode`.
    def wrapped_modelfn(features, labels, mode):
        return multigpu_model_fn(features, labels, mode, modelfn, num_gpus,
                                 weight_decay, momentum, learning_rate)

    def sharded(input_fn):
        # Wraps input_fn so its batch is split into per-GPU shards.
        def split_input():
            images, labels = input_fn()
            return split_batch_input(images, labels, batch_size, num_gpus,
                                     shard_axis)
        return split_input

    train_input_function = sharded(train_inputfn)
    eval_input_function = sharded(eval_inputfn)

    # The examples/sec logger is always installed first. Caller-supplied hooks
    # are appended after it; the caller's list itself is never mutated.
    all_hooks = [ExamplesPerSecondHook(batch_size, every_n_steps=10)]
    all_hooks.extend(train_hooks or [])

    estimator = tf.estimator.Estimator(model_fn=wrapped_modelfn, config=config)

    model_conductor.conduct(
        estimator,
        train_input_function,
        eval_input_function,
        train_steps,
        steps_per_train,
        eval_steps,
        train_hooks=all_hooks,
        target_accuracy=target_accuracy)