Exemplo n.º 1
0
def run_imagenet(flags_obj):
    """Run ResNet ImageNet training and eval loop.

    If `flags_obj.use_fs` is set, `replace_relu_with_fs()` is called before
    the run loop is started.

    Args:
      flags_obj: An object containing parsed flag values. This function reads
        `use_fs` and `use_synthetic_data` from it and forwards the whole
        object to `resnet_run_loop.resnet_main`.

    Returns:
      Dict of results of the run.  Contains the keys `eval_results` and
        `train_hooks`. `eval_results` contains accuracy (top_1) and
        accuracy_top_5. `train_hooks` is a list of the hook instances used
        during training.
    """
    if flags_obj.use_fs:
        replace_relu_with_fs()

    # A real conditional expression rather than `cond and a or b`: the latter
    # would silently fall back to `input_fn` if the synthetic factory ever
    # returned a falsy value.
    input_function = (
        get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
        if flags_obj.use_synthetic_data else input_fn)

    return resnet_run_loop.resnet_main(
        flags_obj,
        imagenet_model_fn,
        input_function,
        DATASET_NAME,
        shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])
Exemplo n.º 2
0
def run_cifar(flags_obj):
    """Run ResNet CIFAR-10 training and eval loop.

    If `flags_obj.use_fs` is set, `replace_relu_with_fs()` is called before
    anything else.

    Args:
      flags_obj: An object containing parsed flag values. This function reads
        `use_fs`, `image_bytes_as_serving_input` and `use_synthetic_data`
        from it and forwards the whole object to
        `resnet_run_loop.resnet_main`.

    Returns:
      Dictionary of results. Including final accuracy. Returns None, after
      logging a fatal message, when `flags_obj.image_bytes_as_serving_input`
      is set, because that flag is only applicable to ImageNet.
    """
    if flags_obj.use_fs:
        replace_relu_with_fs()

    if flags_obj.image_bytes_as_serving_input:
        tf.compat.v1.logging.fatal(
            '--image_bytes_as_serving_input cannot be set to True for CIFAR. '
            'This flag is only applicable to ImageNet.')
        # Explicit bail-out in case the logging call above returns instead of
        # terminating the process.
        return None

    # A real conditional expression rather than `cond and a or b`: the latter
    # would silently fall back to `input_fn` if the synthetic factory ever
    # returned a falsy value.
    input_function = (
        get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
        if flags_obj.use_synthetic_data else input_fn)

    return resnet_run_loop.resnet_main(
        flags_obj,
        cifar10_model_fn,
        input_function,
        DATASET_NAME,
        shape=[HEIGHT, WIDTH, NUM_CHANNELS])