コード例 #1
0
def define_cifar_flags():
    keras_common.define_keras_flags(dynamic_loss_scale=False)

    flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
                            model_dir='/tmp/cifar10_model',
                            train_epochs=182,
                            epochs_between_evals=10,
                            batch_size=128)
コード例 #2
0
 def _setup(self):
     """Setups up and resets flags before each test."""
     tf.logging.set_verbosity(tf.logging.DEBUG)
     if KerasCifar10BenchmarkTests.local_flags is None:
         keras_common.define_keras_flags()
         cifar_main.define_cifar_flags()
         # Loads flags to get defaults to then override. List cannot be empty.
         flags.FLAGS(['foo'])
         saved_flag_values = flagsaver.save_flag_values()
         KerasCifar10BenchmarkTests.local_flags = saved_flag_values
         return
     flagsaver.restore_flag_values(KerasCifar10BenchmarkTests.local_flags)
コード例 #3
0
ファイル: keras_imagenet_main.py プロジェクト: xjm0625/models
def define_imagenet_keras_flags():
  keras_common.define_keras_flags()
  flags_core.set_defaults(train_epochs=90)
コード例 #4
0
    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        num_eval_steps = None
        validation_data = None

    model.fit(train_input_dataset,
              epochs=train_epochs,
              steps_per_epoch=train_steps,
              callbacks=[time_callback, lr_callback, tensorboard_callback],
              validation_steps=num_eval_steps,
              validation_data=validation_data,
              verbose=1)

    if not flags_obj.skip_eval:
        model.evaluate(eval_input_dataset, steps=num_eval_steps, verbose=1)


def main(_):
    with logger.benchmark_context(flags.FLAGS):
        run(flags.FLAGS)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    imagenet_main.define_imagenet_flags()
    keras_common.define_keras_flags()
    absl_app.run(main)
コード例 #5
0
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    return run(flags.FLAGS)


if __name__ == '__main__':
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  imagenet_main.define_imagenet_flags()
  keras_common.define_keras_flags()
  absl_app.run(main)
コード例 #6
0
def define_imagenet_keras_flags():
    imagenet_main.define_imagenet_flags(dynamic_loss_scale=True,
                                        enable_xla=True)
    keras_common.define_keras_flags()
コード例 #7
0
ファイル: keras_cifar_test.py プロジェクト: rder96/models
 def setUpClass(cls):  # pylint: disable=invalid-name
   super(KerasCifarTest, cls).setUpClass()
   cifar10_main.define_cifar_flags()
   keras_common.define_keras_flags()
コード例 #8
0
 def setUpClass(cls):  # pylint: disable=invalid-name
     super(KerasCifarTest, cls).setUpClass()
     cifar10_main.define_cifar_flags()
     keras_common.define_keras_flags()
コード例 #9
0
 def setUpClass(cls):  # pylint: disable=invalid-name
     super(KerasImagenetTest, cls).setUpClass()
     imagenet_main.define_imagenet_flags()
     keras_common.define_keras_flags()
コード例 #10
0
def define_imagenet_keras_flags():
    keras_common.define_keras_flags()
    flags_core.set_defaults(train_epochs=90)
    flags.adopt_module_key_flags(keras_common)
コード例 #11
0
ファイル: imagenet_test.py プロジェクト: Exscotticus/models
 def setUpClass(cls):  # pylint: disable=invalid-name
   super(BaseTest, cls).setUpClass()
   imagenet_main.define_imagenet_flags()
   keras_common.define_keras_flags()
コード例 #12
0
 def setUpClass(cls):  # pylint: disable=invalid-name
     super(CtlImagenetTest, cls).setUpClass()
     keras_common.define_keras_flags()
     ctl_common.define_ctl_flags()