示例#1
0
def main(argv):
    """Parse ResNet flags, count the examples per split and start training.

    Args:
      argv: Full command-line argument list. ``argv[0]`` (the program name)
        is skipped; the remainder is handed to the ResNet argument parser.
    """
    arg_parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])
    arg_parser.set_defaults(
        train_epochs=100, data_dir='./data', model_dir='./model')
    flags = arg_parser.parse_args(args=argv[1:])

    # Fill the _NUM_IMAGES dict (defined outside this function) by iterating
    # every record of each split's TFRecord file once.
    for split_name, file_name in (('train', 'train.tfrecord'),
                                  ('test', 'test.tfrecord')):
        record_path = os.path.join(flags.data_dir, file_name)
        _NUM_IMAGES[split_name] = sum(
            1 for _ in tf.python_io.tf_record_iterator(record_path))

    # Reference settings noted by the original author: batch_size=32,
    # resnet_size=50, version=2, epochs_between_evals=1,
    # max_train_steps=None (data_dir, model_dir and train_epochs match the
    # defaults set above).
    resnet_run_loop.resnet_main(flags, model_fn, input_fn)
示例#2
0
def main(argv):
    """Seed every RNG from ``argv[1]`` and launch ImageNet ResNet training.

    Args:
      argv: Command-line argument list. ``argv[1]`` must be an integer
        random seed (``int()`` raises ValueError otherwise); flags are parsed
        from ``argv[2:]``.
    """
    arg_parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 26, 34, 50, 101, 152, 200])
    arg_parser.set_defaults(train_epochs=90, version=1)
    flags = arg_parser.parse_args(args=argv[2:])

    if flags.oss_load:
        # NOTE(review): the bucket handle is never used after this point, so
        # it is built and discarded -- confirm whether it should be passed
        # on to the input pipeline.
        oss_auth = oss2.Auth(_ACCESS_ID, _ACCESS_KEY)
        oss_bucket = oss2.Bucket(oss_auth, _HOST, _BUCKET)

    seed = int(argv[1])
    print('Setting random seed = ', seed)
    print('special seeding')
    mlperf_log.resnet_print(key=mlperf_log.RUN_SET_RANDOM_SEED, value=seed)
    # Seed the Python, TensorFlow and NumPy generators with the same value.
    for seed_fn in (random.seed, tf.set_random_seed, np.random.seed):
        seed_fn(seed)

    for log_key, split in ((mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES, 'train'),
                           (mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
                            'validation')):
        mlperf_log.resnet_print(key=log_key, value=_NUM_IMAGES[split])

    image_shape = [_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS]
    resnet_run_loop.resnet_main(seed, flags, imagenet_model_fn, input_fn,
                                shape=image_shape)
示例#3
0
def main(argv):
    """Parse flags, pick the lmk/no-lmk model dir, count examples and train.

    Args:
      argv: Full command-line argument list; ``argv[0]`` is skipped.
    """
    arg_parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])
    arg_parser.set_defaults(train_epochs=100, data_dir='./data')
    flags = arg_parser.parse_args(args=argv[1:])

    # flags.model_dir is always overwritten from --no_lmk, so any value
    # parsed from the command line is discarded here.
    if flags.no_lmk:
        flags.model_dir = './no-lmk-model'
    else:
        flags.model_dir = './lmk-model'

    # NOTE(review): only the first file returned by get_filenames is
    # counted for each split -- confirm each split is a single TFRecord file.
    for split_name, is_training in (('train', True), ('test', False)):
        first_file = get_filenames(is_training, flags.data_dir)[0]
        _NUM_IMAGES[split_name] = sum(
            1 for _ in tf.python_io.tf_record_iterator(first_file))

    # Reference settings noted by the original author: batch_size=32,
    # no-lmk=False, data_dir='./data', model_dir='./lmk-model',
    # resnet_size=50, version=2, train_epochs=100, epochs_between_evals=1,
    # max_train_steps=None.
    resnet_run_loop.resnet_main(flags, model_fn, input_fn)
def main(argv):
    """Parse flags, count the TFRecord examples per split and run ResNet.

    Args:
      argv: Full command-line argument list; ``argv[0]`` is skipped.
    """
    parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])

    parser.set_defaults(train_epochs=100,
                        data_dir='../dataset',
                        model_dir='./model')

    flags = parser.parse_args(args=argv[1:])

    # Count the records of each split's TFRecord file into the _NUM_IMAGES
    # dict defined outside this function. NOTE(review): this runs even when
    # --use_synthetic_data is set, so the files must exist in that case too.
    for split, file_name in (('train', 'train.tfrecord'),
                             ('validation', 'validation.tfrecord'),
                             ('predict', 'predict.tfrecord')):
        record_path = os.path.join(flags.data_dir, file_name)
        _NUM_IMAGES[split] = sum(
            1 for _ in tf.python_io.tf_record_iterator(record_path))

    # Explicit conditional instead of `cond and a or b`, which would silently
    # fall back to input_fn if the synthetic input function were falsy.
    input_function = (get_synth_input_fn() if flags.use_synthetic_data
                      else input_fn)

    resnet_run_loop.resnet_main(
        flags,
        model_fn,
        input_function,
        shape=[_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])
示例#5
0
def run_cifar(flags_obj):
    """Run the ResNet CIFAR training and evaluation loop.

    Args:
      flags_obj: An object containing parsed flag values.
    """
    input_function = input_fn

    # `shape` is [height, width, channels], as in the other resnet_main
    # calls; the last entry was previously _NUM_CLASSES (the number of
    # labels), which is not an image dimension.
    resnet_run_loop.resnet_main(flags_obj,
                                cifar_10_model_fn,
                                input_function,
                                DATASET_NAME,
                                shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
示例#6
0
def run_imagenet(flags_obj):
  """Train and evaluate ResNet on ImageNet.

  This variant always passes ``input_fn``; it has no synthetic-data switch.

  Args:
    flags_obj: Parsed flag values, forwarded to resnet_run_loop.resnet_main.
  """
  image_shape = [_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS]
  resnet_run_loop.resnet_main(flags_obj, imagenet_model_fn, input_fn,
                              DATASET_NAME, shape=image_shape)
def main(argv):
    """Parse flags and run the ImageNet ResNet training/evaluation loop.

    Args:
      argv: Full command-line argument list; ``argv[0]`` is skipped.
    """
    parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])

    parser.set_defaults(train_epochs=100)

    flags = parser.parse_args(args=argv[1:])

    # Explicit branch instead of `cond and a or b`, which would silently
    # fall back to input_fn if the synthetic input function were falsy.
    if flags.use_synthetic_data:
        input_function = get_synth_input_fn()
    else:
        input_function = input_fn
    resnet_run_loop.resnet_main(flags, imagenet_model_fn, input_function)
示例#8
0
def run_imagenet(flags_obj):
    """Run ResNet ImageNet training and eval loop.

    Args:
      flags_obj: An object containing parsed flag values. Its
        ``use_synthetic_data`` attribute selects synthetic input, built with
        the dtype from ``flags_core.get_tf_dtype(flags_obj)``.
    """
    # Explicit conditional instead of `cond and a or b`, which would silently
    # fall back to input_fn if the synthetic input function were falsy.
    if flags_obj.use_synthetic_data:
        input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
    else:
        input_function = input_fn

    resnet_run_loop.resnet_main(
        flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
        shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
示例#9
0
def main(argv):
  """Parse CIFAR-10 flags (with nGraph as the default device) and train.

  Args:
    argv: Full command-line argument list; ``argv[0]`` is skipped.
  """
  parser = resnet_run_loop.ResnetArgParser()
  # Set defaults that are reasonable for this model.
  parser.set_defaults(data_dir='/tmp/cifar10_data',
                      model_dir='/tmp/cifar10_model',
                      resnet_size=32,
                      train_epochs=250,
                      epochs_between_evals=10,
                      batch_size=128,
                      select_device='NGRAPH')

  flags = parser.parse_args(args=argv[1:])
  # Explicit conditional instead of `cond and a or b`, which would silently
  # fall back to input_fn if the synthetic input function were falsy.
  input_function = (get_synth_input_fn() if flags.use_synthetic_data
                    else input_fn)
  resnet_run_loop.resnet_main(flags, cifar10_model_fn, input_function)
def run_imagenet(flags_obj):
    """Run ResNet ImageNet training and eval loop.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The value returned by ``resnet_run_loop.resnet_main``. Per the original
      documentation this is a dict with the keys ``eval_results`` (including
      ``accuracy`` (top-1) and ``accuracy_top_5``) and ``train_hooks`` (a list
      of the hook instances used during training) -- confirm against
      resnet_main.
    """
    # Choose between synthetic and real input: get_synth_input_fn produces
    # randomly generated data (in the dtype from the flags), while input_fn
    # reads the real input data. An explicit conditional replaces
    # `cond and a or b`, which falls back to input_fn if `a` is falsy.
    if flags_obj.use_synthetic_data:
        input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
    else:
        input_function = input_fn

    result = resnet_run_loop.resnet_main(
        flags_obj,
        imagenet_model_fn,
        input_function,
        DATASET_NAME,
        shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])

    return result
def run_cifar(flags_obj, config, conf_matrix):
    """Run ResNet CIFAR-10 training and eval loop.

    Args:
      flags_obj: An object containing parsed flag values.
      config: Object whose private ``_mode`` attribute selects the input
        function (``'predict'`` uses ``input_fn_predict``); it is also
        forwarded to ``resnet_run_loop.resnet_main``.
      conf_matrix: Forwarded unchanged to ``resnet_run_loop.resnet_main``
        (presumably a confusion-matrix accumulator -- confirm there).
    """
    # NOTE(review): reads the private attribute config._mode; a public
    # accessor would be preferable if the config type provides one.
    if config._mode == 'predict':
        input_function = input_fn_predict
    elif flags_obj.use_synthetic_data:
        # Explicit branch instead of `cond and a or b`, which would silently
        # fall back to input_fn if the synthetic input function were falsy.
        input_function = get_synth_input_fn()
    else:
        input_function = input_fn
    resnet_run_loop.resnet_main(flags_obj,
                                config,
                                conf_matrix,
                                cifar10_model_fn,
                                input_function,
                                DATASET_NAME,
                                shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
示例#12
0
def main(argv):
    """Parse CIFAR-10 flags and run the ResNet training/evaluation loop.

    Args:
      argv: Full command-line argument list; ``argv[0]`` is skipped.
    """
    parser = resnet_run_loop.ResnetArgParser()
    # Set defaults that are reasonable for this model.
    parser.set_defaults(data_dir='cifar10_data',
                        model_dir='cifar10_model',
                        resnet_size=32,
                        train_epochs=250,
                        epochs_between_evals=10,
                        batch_size=128)

    flags = parser.parse_args(args=argv[1:])
    # A leftover `pdb.set_trace()` breakpoint used to sit here and halted
    # every run at an interactive debugger prompt; it has been removed.
    # An explicit conditional replaces `cond and a or b`, which would silently
    # fall back to input_fn if the synthetic input function were falsy.
    input_function = (get_synth_input_fn() if flags.use_synthetic_data
                      else input_fn)

    resnet_run_loop.resnet_main(flags,
                                cifar10_model_fn,
                                input_function,
                                shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The value returned by ``resnet_run_loop.resnet_main``. The author who
    added this return named it ``eval_accuracy``; confirm its exact type
    against resnet_main.
  """
  # Explicit branch instead of `cond and a or b`, which would silently fall
  # back to input_fn if the synthetic input function were falsy.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn()
  else:
    input_function = input_fn
  # Local modification (Xinyi): propagate resnet_main's result to the caller.
  eval_accuracy = resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
  return eval_accuracy
示例#14
0
def run_cifar(flags_obj):
    """Run ResNet CIFAR-10 training and eval loop.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The value returned by ``resnet_run_loop.resnet_main`` (documented
      upstream as a dictionary of results including final accuracy), or
      ``None`` when ``image_bytes_as_serving_input`` is set, because that
      flag is rejected for CIFAR.
    """
    if flags_obj.image_bytes_as_serving_input:
        # NOTE(review): the explicit return below suggests logging.fatal does
        # not abort the process by itself here -- confirm.
        tf.compat.v1.logging.fatal(
            '--image_bytes_as_serving_input cannot be set to True for CIFAR. '
            'This flag is only applicable to ImageNet.')
        return

    # Explicit conditional instead of `cond and a or b`, which would silently
    # fall back to input_fn if the synthetic input function were falsy.
    if flags_obj.use_synthetic_data:
        input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
    else:
        input_function = input_fn
    result = resnet_run_loop.resnet_main(flags_obj,
                                         cifar10_model_fn,
                                         input_function,
                                         DATASET_NAME,
                                         shape=[HEIGHT, WIDTH, NUM_CHANNELS])

    return result