예제 #1
0
def define_imagenet_flags():
    """Define the ImageNet ResNet flags with fp16 options enabled.

    Delegates flag definition to ``resnet_run_loop.define_resnet_flags``,
    restricting the ResNet size choices to 18/34/50/101/152/200 and passing
    ``dynamic_loss_scale=True`` and ``fp16_implementation=True``. It then
    adopts ``resnet_run_loop``'s key flags and sets the default
    ``train_epochs`` to 90.
    """
    fp16_options = dict(dynamic_loss_scale=True, fp16_implementation=True)
    resnet_run_loop.define_resnet_flags(
        resnet_size_choices=['18', '34', '50', '101', '152', '200'],
        **fp16_options)
    flags.adopt_module_key_flags(resnet_run_loop)
    flags_core.set_defaults(train_epochs=90)
예제 #2
0
def define_cifar_flags(hp, model_id, model_dir, data_dir, train_epochs, total_epochs, epoch_index): # Xinyi modified
  """Define CIFAR ResNet flags plus per-model hyperparameter flags.

  Calls ``resnet_run_loop.define_resnet_flags()`` and adopts that module's key
  flags. It then defines optimizer, learning-rate, initialization,
  regularization and bookkeeping flags whose defaults come from the
  arguments. Finally it overrides the defaults of several existing flags
  through ``flags_core.set_defaults``.

  Args:
    hp: Nested hyperparameter mapping. Keys read here:
      ``hp['opt_case']['optimizer']`` and ``hp['opt_case']['lr']``.
      ``hp['opt_case']['momentum']`` is read only when the optimizer is
      'Momentum' or 'RMSProp'. ``hp['opt_case']['grad_decay']`` is read only
      for 'RMSProp'. Also read: ``hp['decay_rate']``, ``hp['decay_steps']``,
      ``hp['initializer']``, ``hp['regularizer']``, ``hp['weight_decay']``
      and ``hp['batch_size']``.
    model_id: Default for ``--model_id`` (index of the model in the
      population).
    model_dir: Default for ``--model_dir``.
    data_dir: Default for ``--data_dir``.
    train_epochs: Default for ``--train_epochs``.
    total_epochs: Default for ``--total_epochs``.
    epoch_index: Default for ``--epoch_index`` (epoch index written to csv).
  """
  resnet_run_loop.define_resnet_flags()
  flags.adopt_module_key_flags(resnet_run_loop)

  # Additional hyperparameter flags (added by Xinyi).
  opt_case = hp['opt_case']
  optimizer = opt_case['optimizer']

  flags.DEFINE_string(
      name="optimizer", short_name="opt", default=optimizer,
      help=help_wrap("The name of optimizer type"))
  # Only optimizers that use a hyperparameter get its flag, so hp need not
  # carry 'momentum' / 'grad_decay' for other optimizers.
  if optimizer in ('Momentum', 'RMSProp'):
    flags.DEFINE_float(
        name="momentum", short_name="mm",
        default=opt_case['momentum'],
        help=help_wrap("The momentum of Momentum SGD or RMSProp"))
  if optimizer == 'RMSProp':
    flags.DEFINE_float(
        name="grad_decay", short_name="rmspd",
        default=opt_case['grad_decay'],
        help=help_wrap("The decay of RMSProp"))
  flags.DEFINE_float(
      name="learning_rate", short_name="lr",
      default=opt_case['lr'],
      help=help_wrap("The initial learning rate of optimizer"))
  flags.DEFINE_float(
      name="decay_rate", short_name="lrdr", default=hp['decay_rate'],
      help=help_wrap("The base term of learning rate decay function"))
  # Adjacent string literals are concatenated verbatim, so each sentence
  # carries its own ". " separator.
  flags.DEFINE_integer(
      name="decay_steps", short_name="lrds", default=hp['decay_steps'],
      help=help_wrap("The power term of learning rate decay function. "
                     "This value is in percentage of train_epochs. "
                     "Zero value means turning off decay"))
  flags.DEFINE_string(
      name="initializer", short_name="initn", default=hp['initializer'],
      help=help_wrap("The name of initialization method. "
                     "None value means glorot_uniform_initializer"))
  flags.DEFINE_string(
      name="regularizer", short_name="regn", default=hp['regularizer'],
      help=help_wrap("The name of regularization method. "
                     "None value means turning off weight decay"))
  flags.DEFINE_float(
      name="weight_decay", short_name="wd", default=hp['weight_decay'],
      help=help_wrap("The amount of regularization. "
                     "If regularizer=None, the variable becomes useless"))
  flags.DEFINE_integer(
      name="model_id", short_name="mid", default=model_id,
      help=help_wrap("The index of model in the population"))
  # Use the dedicated total_epochs argument. It previously defaulted to
  # train_epochs, which left the total_epochs parameter unused.
  flags.DEFINE_integer(
      name="total_epochs", short_name="ttep", default=total_epochs,
      help=help_wrap("The total epochs the model will be trained"))
  flags.DEFINE_integer(
      name="epoch_index", short_name="epi", default=epoch_index,
      help=help_wrap("The epoch index write to csv."))

  flags_core.set_defaults(data_dir=data_dir,
                          model_dir=model_dir,
                          resnet_size='50',
                          train_epochs=train_epochs,
                          epochs_between_evals=1,
                          batch_size=hp['batch_size'])
예제 #3
0
def define_cifar_flags():
    """Define the CIFAR ResNet flags and override a set of their defaults.

    Calls ``resnet_run_loop.define_resnet_flags()`` with no arguments and
    adopts that module's key flags. It then sets defaults for the data and
    model directories, ResNet size ('32'), 2500 training epochs, evaluation
    every 10 epochs and a batch size of 128.
    """
    resnet_run_loop.define_resnet_flags()
    flags.adopt_module_key_flags(resnet_run_loop)

    # NOTE(review): data_dir is a machine-specific absolute path. Other
    # environments presumably need to override it (e.g. via --data_dir).
    defaults = {
        'data_dir': '/home/yotamg/data/rgb/train',
        'model_dir': RESOURCES_OUT_DIR,
        'resnet_size': '32',
        'train_epochs': 2500,
        'epochs_between_evals': 10,
        'batch_size': 128,
    }
    flags_core.set_defaults(**defaults)
예제 #4
0
def define_cifar_flags():
    """Define the CIFAR ResNet flags with ResNet-56 training defaults.

    After ``resnet_run_loop.define_resnet_flags()`` and adopting that
    module's key flags, the defaults are set to:
    - empty data/model directories;
    - ResNet size '56';
    - 182 training epochs, evaluating every 5 epochs;
    - batch size 128;
    - ``image_bytes_as_serving_input=False``.
    """
    resnet_run_loop.define_resnet_flags()
    flags.adopt_module_key_flags(resnet_run_loop)
    cifar_defaults = dict(
        data_dir='',
        model_dir='',
        resnet_size='56',
        train_epochs=182,
        epochs_between_evals=5,
        batch_size=128,
        image_bytes_as_serving_input=False,
    )
    flags_core.set_defaults(**cifar_defaults)
예제 #5
0
def define_imagenet_flags():
    """Define the ImageNet ResNet flags.

    Limits the ResNet size choices passed to
    ``resnet_run_loop.define_resnet_flags`` to 18/34/50/101/152/200. It then
    adopts that module's key flags and sets the default ``train_epochs``
    to 90.
    """
    allowed_sizes = ['18', '34', '50', '101', '152', '200']
    resnet_run_loop.define_resnet_flags(resnet_size_choices=allowed_sizes)
    flags.adopt_module_key_flags(resnet_run_loop)
    flags_core.set_defaults(train_epochs=90)