예제 #1
0
def main(argv):
  """Seed the random number generators from argv and launch ResNet training.

  Args:
    argv: Command-line arguments. argv[1] must be an integer random seed and
      argv[2:] are parsed by ResnetArgParser. argv[0] is not read
      (presumably the program name).
  """
  parser = resnet_run_loop.ResnetArgParser(
      resnet_size_choices=[18, 34, 50, 101, 152, 200])

  # Script-level defaults; explicit command-line flags still override them.
  parser.set_defaults(
      train_epochs=90,
      version=1
  )

  flags = parser.parse_args(args=argv[2:])

  seed = int(argv[1])
  print('Setting random seed = ', seed)
  print('special seeding')
  mlperf_log.resnet_print(key=mlperf_log.RUN_SET_RANDOM_SEED, value=seed)
  # Seed the Python, TensorFlow and NumPy random number generators.
  random.seed(seed)
  tf.set_random_seed(seed)
  numpy.random.seed(seed)

  # Log the train / validation dataset sizes for MLPerf compliance output.
  mlperf_log.resnet_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES,
                          value=_NUM_IMAGES['train'])
  mlperf_log.resnet_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
                          value=_NUM_IMAGES['validation'])

  # Conditional expression rather than `flag and fn() or input_fn`: the
  # and/or idiom silently falls back to input_fn whenever the middle operand
  # is falsy, i.e. it could train on real data although synthetic data was
  # requested.
  input_function = (get_synth_input_fn() if flags.use_synthetic_data
                    else input_fn)

  resnet_run_loop.resnet_main(
      seed,
      flags,
      imagenet_model_fn,
      input_function,
      shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
예제 #2
0
def main(argv):
    """Parse ResNet flags from the command line and launch training.

    Args:
        argv: Command-line arguments; argv[1:] are parsed by ResnetArgParser
            (argv[0] is not read, presumably the program name).
    """
    parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])
    flags = parser.parse_args(args=argv[1:])

    # Conditional expression rather than `flag and fn() or input_fn`: the
    # and/or idiom would silently fall back to input_fn if
    # get_synth_input_fn() ever returned a falsy value.
    input_function = (get_synth_input_fn() if flags.use_synthetic_data
                      else input_fn)
    resnet_run_loop.resnet_main(flags, imagenet_model_fn, input_function)
예제 #3
0
def main(argv):
  """Parse CIFAR-10 ResNet flags from the command line and launch training.

  Args:
    argv: Command-line arguments; argv[1:] are parsed by ResnetArgParser
      (argv[0] is not read, presumably the program name).
  """
  parser = resnet_run_loop.ResnetArgParser()
  # Set defaults that are reasonable for this model; explicit command-line
  # flags still override them.
  parser.set_defaults(data_dir='/tmp/cifar10_data',
                      model_dir='/tmp/cifar10_model',
                      resnet_size=32,
                      train_epochs=250,
                      epochs_between_evals=10,
                      batch_size=128)

  flags = parser.parse_args(args=argv[1:])

  # Conditional expression rather than `flag and fn() or input_fn`: the
  # and/or idiom would silently fall back to input_fn if
  # get_synth_input_fn() ever returned a falsy value.
  input_function = (get_synth_input_fn() if flags.use_synthetic_data
                    else input_fn)
  resnet_run_loop.resnet_main(flags, cifar10_model_fn, input_function)
예제 #4
0
def main(argv):
    """Parse ResNet flags from the command line and launch training.

    Args:
        argv: Command-line arguments; argv[1:] are parsed by ResnetArgParser
            (argv[0] is not read, presumably the program name).
    """
    parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])

    # Script-level default; an explicit --train_epochs flag still wins.
    parser.set_defaults(train_epochs=100)

    flags = parser.parse_args(args=argv[1:])

    # Conditional expression rather than `flag and fn() or input_fn`: the
    # and/or idiom would silently fall back to input_fn if
    # get_synth_input_fn() ever returned a falsy value.
    input_function = (get_synth_input_fn() if flags.use_synthetic_data
                      else input_fn)

    resnet_run_loop.resnet_main(
        flags,
        imagenet_model_fn,
        input_function,
        shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
예제 #5
0
def main(argv):
    """Parse ResNet flags from the command line and launch training.

    Args:
        argv: Command-line arguments; argv[1:] are parsed by ResnetArgParser
            (argv[0] is not read, presumably the program name).
    """
    parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])

    # Script-level default; an explicit --train_epochs flag still wins.
    parser.set_defaults(train_epochs=100)

    flags = parser.parse_args(args=argv[1:])

    # Conditional expression rather than `flag and fn() or input_fn`: the
    # and/or idiom would silently fall back to input_fn if
    # get_synth_input_fn() ever returned a falsy value.
    input_function = (get_synth_input_fn() if flags.use_synthetic_data
                      else input_fn)

    # The train and validation example counts are passed positionally to
    # resnet_main after the input function.
    resnet_run_loop.resnet_main(
        flags,
        imagenet_model_fn,
        input_function,
        _NUM_IMAGES['train'],
        _NUM_IMAGES['validation'],
        shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
예제 #6
0
def main(argv):
    """Seed the random number generators from argv and launch ResNet training.

    Args:
        argv: Command-line arguments. argv[1] must be an integer random seed
            and argv[2:] are parsed by ResnetArgParser. argv[0] is not read
            (presumably the program name).
    """
    parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])

    # Script-level default; an explicit --train_epochs flag still wins.
    parser.set_defaults(train_epochs=100)

    flags = parser.parse_args(args=argv[2:])

    seed = int(argv[1])
    print('Setting random seed = ', seed)
    print('special seeding')
    # Seed the Python, TensorFlow and NumPy random number generators.
    random.seed(seed)
    tf.set_random_seed(seed)
    numpy.random.seed(seed)

    # Conditional expression rather than `flag and fn() or input_fn`: the
    # and/or idiom would silently fall back to input_fn if
    # get_synth_input_fn() ever returned a falsy value.
    input_function = (get_synth_input_fn() if flags.use_synthetic_data
                      else input_fn)

    resnet_run_loop.resnet_main(
        seed,
        flags,
        imagenet_model_fn,
        input_function,
        shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
예제 #7
0
        boundary_epochs=[30, 60, 80, 90],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

    return resnet_run_loop.resnet_model_fn(features,
                                           labels,
                                           mode,
                                           ImagenetModel,
                                           resnet_size=params['resnet_size'],
                                           weight_decay=1e-4,
                                           learning_rate_fn=learning_rate_fn,
                                           momentum=0.9,
                                           data_format=params['data_format'],
                                           version=params['version'],
                                           loss_filter_fn=None,
                                           multi_gpu=params['multi_gpu'])


def main(unused_argv):
    """Launch ResNet training using the module-level FLAGS.

    Args:
        unused_argv: Ignored; flags are read from the module-level FLAGS.
    """
    # Conditional expression rather than `flag and fn() or input_fn`: the
    # and/or idiom would silently fall back to input_fn if
    # get_synth_input_fn() ever returned a falsy value.
    input_function = (get_synth_input_fn() if FLAGS.use_synthetic_data
                      else input_fn)
    resnet_run_loop.resnet_main(FLAGS, imagenet_model_fn, input_function)


if __name__ == '__main__':
    # Emit TensorFlow INFO-level log messages.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Parse the flags ResnetArgParser recognises into the module-level FLAGS
    # that main() reads. Unrecognised arguments are forwarded to tf.app.run
    # after the program name, and tf.app.run then invokes main().
    parser = resnet_run_loop.ResnetArgParser(
        resnet_size_choices=[18, 34, 50, 101, 152, 200])
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(argv=[sys.argv[0]] + unparsed)
예제 #8
0
                                           resnet_size=params['resnet_size'],
                                           weight_decay=weight_decay,
                                           learning_rate_fn=learning_rate_fn,
                                           momentum=0.9,
                                           data_format=params['data_format'],
                                           version=params['version'],
                                           loss_filter_fn=loss_filter_fn,
                                           multi_gpu=params['multi_gpu'])


def main(unused_argv):
    """Launch CIFAR-10 ResNet training using the module-level FLAGS.

    Args:
        unused_argv: Ignored; flags are read from the module-level FLAGS.
    """
    # Conditional expression rather than `flag and fn() or input_fn`: the
    # and/or idiom would silently fall back to input_fn if
    # get_synth_input_fn() ever returned a falsy value.
    input_function = (get_synth_input_fn() if FLAGS.use_synthetic_data
                      else input_fn)
    resnet_run_loop.resnet_main(FLAGS, cifar10_model_fn, input_function)


if __name__ == '__main__':
    # Emit TensorFlow INFO-level log messages.
    tf.logging.set_verbosity(tf.logging.INFO)

    parser = resnet_run_loop.ResnetArgParser()
    # Set defaults that are reasonable for this model.
    # NOTE(review): argparse's set_defaults accepts keys that match no defined
    # flag without complaint. Confirm that `epochs_per_eval` (not
    # `epochs_between_evals`) is the flag name this ResnetArgParser version
    # defines; otherwise this default is silently ignored by the run loop.
    parser.set_defaults(data_dir='/tmp/cifar10_data',
                        model_dir='/tmp/cifar10_model',
                        resnet_size=32,
                        train_epochs=250,
                        epochs_per_eval=10,
                        batch_size=128)

    # Known flags populate the module-level FLAGS read by main(). Unrecognised
    # arguments are forwarded to tf.app.run after the program name.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(argv=[sys.argv[0]] + unparsed)