def imagenet_model_fn(features, labels, mode, params):
  """Build the ImageNet ResNet model_fn result for the Estimator.

  Chooses a learning-rate schedule based on params['flag_bias'], derives the
  weight decay from the current iteration index, and delegates the rest of
  the work to resnet.resnet_model_fn.

  Args:
    features: forwarded unchanged to resnet.resnet_model_fn.
    labels: forwarded unchanged to resnet.resnet_model_fn.
    mode: forwarded unchanged to resnet.resnet_model_fn.
    params: mapping read for the keys 'flag_bias', 'batch_size',
      'num_train_images', 'itera', 'resnet_size', 'data_format',
      'nb_groups' and 'restore_model_dir'.

  Returns:
    Whatever resnet.resnet_model_fn returns.
  """
  base_boundaries = (30, 60, 80, 90)

  # Explicit equality (rather than truthiness) is kept so that only values
  # equal to True select the bias-correction schedule.
  bias_phase = params['flag_bias'] == True

  # Both phases share the same decay rates; the bias-correction phase only
  # stretches the epoch boundaries by a factor of _BIAS_EPOCHS.
  learning_rate_fn = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=params['num_train_images'],
      boundary_epochs=(
          [epoch * _BIAS_EPOCHS for epoch in base_boundaries]
          if bias_phase else list(base_boundaries)),
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  # The weight decay is largest at itera == 0 and shrinks as itera grows.
  # NOTE(review): `nb_groups` is looked up as a bare name here, not as
  # params['nb_groups'] (which is what gets forwarded below) -- confirm both
  # refer to the same value.
  decay = 1e-4 * nb_groups / (params['itera'] + 1)

  return resnet.resnet_model_fn(
      features,
      labels,
      mode,
      ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      itera=params['itera'],
      nb_groups=params['nb_groups'],
      restore_model_dir=params['restore_model_dir'],
      flag_bias=params['flag_bias'],
      loss_filter_fn=None)
Пример #2
0
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10, built on resnet.resnet_model_fn.

  Args:
    features: input tensor; reshaped to
      [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS] before use.
    labels: forwarded unchanged to resnet.resnet_model_fn.
    mode: forwarded unchanged to resnet.resnet_model_fn.
    params: mapping read for the keys 'batch_size', 'resnet_size' and
      'data_format'.

  Returns:
    Whatever resnet.resnet_model_fn returns.
  """
  def keep_every_variable(name):
    # The original authors report that also regularizing the
    # batch-normalization variables helped validation accuracy on CIFAR-10,
    # so no variable is filtered out of the regularized loss.
    return True

  images = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

  schedule = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=128,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[100, 150, 200],
      decay_rates=[1, 0.1, 0.01, 0.001])

  return resnet.resnet_model_fn(
      images,
      labels,
      mode,
      Cifar10Model,
      resnet_size=params['resnet_size'],
      # 0.0002 is used instead of the originally suggested 0.0001 because,
      # per the original authors, it performs better.
      weight_decay=2e-4,
      learning_rate_fn=schedule,
      momentum=0.9,
      data_format=params['data_format'],
      loss_filter_fn=keep_every_variable)
Пример #3
0
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10, built on resnet.resnet_model_fn.

    Args:
        features: input tensor; reshaped to
            [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS] before use.
        labels: forwarded unchanged to resnet.resnet_model_fn.
        mode: forwarded unchanged to resnet.resnet_model_fn.
        params: mapping read for the keys 'batch_size', 'resnet_size',
            'data_format' and 'multi_gpu'.

    Returns:
        Whatever resnet.resnet_model_fn returns.
    """
    def keep_every_variable(name):
        # The original authors report that also regularizing the
        # batch-normalization variables helped validation accuracy on
        # CIFAR-10, so no variable is excluded from the regularized loss.
        return True

    images = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

    schedule = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'], batch_denom=128,
        num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200],
        decay_rates=[1, 0.1, 0.01, 0.001])

    return resnet.resnet_model_fn(
        images, labels, mode, Cifar10Model,
        resnet_size=params['resnet_size'],
        # 0.0002 is used instead of the originally suggested 0.0001
        # because, per the original authors, it performs better.
        weight_decay=2e-4,
        learning_rate_fn=schedule,
        momentum=0.9,
        data_format=params['data_format'],
        loss_filter_fn=keep_every_variable,
        multi_gpu=params['multi_gpu'])
Пример #4
0
def imagenet_model_fn(features, labels, mode, params):
  """ImageNet ResNet model_fn for use with the Estimator.

  Args:
    features: forwarded unchanged to resnet.resnet_model_fn.
    labels: forwarded unchanged to resnet.resnet_model_fn.
    mode: forwarded unchanged to resnet.resnet_model_fn.
    params: mapping read for the keys 'batch_size', 'resnet_size' and
      'data_format'.

  Returns:
    Whatever resnet.resnet_model_fn returns.
  """
  # Epoch boundaries and the learning-rate multiplier used in each of the
  # five resulting intervals.
  epoch_boundaries = [30, 60, 80, 90]
  rate_multipliers = [1, 0.1, 0.01, 0.001, 1e-4]

  schedule = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=epoch_boundaries,
      decay_rates=rate_multipliers)

  return resnet.resnet_model_fn(
      features,
      labels,
      mode,
      ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=1e-4,
      learning_rate_fn=schedule,
      momentum=0.9,
      data_format=params['data_format'],
      loss_filter_fn=None)
Пример #5
0
def imagenet_model_fn(features, labels, mode, params):
    """ImageNet ResNet model_fn for use with the Estimator.

    Args:
        features: forwarded unchanged to resnet.resnet_model_fn.
        labels: forwarded unchanged to resnet.resnet_model_fn.
        mode: forwarded unchanged to resnet.resnet_model_fn.
        params: mapping read for the keys 'batch_size', 'resnet_size' and
            'data_format'.

    Returns:
        Whatever resnet.resnet_model_fn returns.
    """
    # Epoch boundaries and the learning-rate multiplier used in each of
    # the five resulting intervals.
    epoch_boundaries = [30, 60, 80, 90]
    rate_multipliers = [1, 0.1, 0.01, 0.001, 1e-4]

    schedule = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'], batch_denom=256,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=epoch_boundaries,
        decay_rates=rate_multipliers)

    return resnet.resnet_model_fn(
        features, labels, mode, ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=1e-4,
        learning_rate_fn=schedule,
        momentum=0.9,
        data_format=params['data_format'],
        loss_filter_fn=None)
Пример #6
0
def cifar_model_fn(features, labels, mode, params):
    """
    Builds the CIFAR-10 residual network via resnet_model.resnet_model_fn
    :param features: the image data; reshaped to
        [-1, _WIDTH, _HEIGHT, _CHANNELS] before use
    :param labels: the labels of images
    :param mode: the mode, training, evaluation or prediction
    :param params: mapping read for the keys 'batch_size', 'resnet_size',
        'data_format' and 'multi_gpu'
    :return: whatever resnet_model.resnet_model_fn returns
    """
    images = tf.reshape(features, [-1, _WIDTH, _HEIGHT, _CHANNELS])

    # Learning-rate schedule: multipliers 1, 0.1, 0.01, 0.001 separated by
    # the epoch boundaries 100, 150 and 200. No batch_denom is passed here.
    # NOTE(review): the schedule comes from `resnet` while the model fn
    # comes from `resnet_model` -- confirm both modules are imported.
    epoch_boundaries = [100, 150, 200]
    rate_multipliers = [1, 0.1, 0.01, 0.001]
    schedule = resnet.learning_rate_with_decay(
        num_images=_NUM_IMAGES['train'],
        batch_size=params['batch_size'],
        boundary_epochs=epoch_boundaries,
        decay_rates=rate_multipliers)

    l2_weight_decay = 2e-4
    momentum = 0.9

    # All arguments are passed positionally, in this exact order.
    # NOTE(review): confirm that the final positional slot of
    # resnet_model.resnet_model_fn is the multi-GPU flag and not another
    # optional parameter (e.g. a loss filter).
    return resnet_model.resnet_model_fn(
        images,
        labels,
        mode,
        Cifar10Model,
        params['resnet_size'],
        l2_weight_decay,
        schedule,
        momentum,
        params['data_format'],
        params['multi_gpu'])