Example #1
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10."""
    features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

    learning_rate_fn = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=128,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=[100, 150, 200],
        decay_rates=[1, 0.1, 0.01, 0.001])

    # We use a weight decay of 0.0002, which performs better
    # than the 0.0001 that was originally suggested.
    weight_decay = 2e-4

    # Empirical testing showed that including batch_normalization variables
    # in the calculation of regularized loss helped validation accuracy
    # for the CIFAR-10 dataset, perhaps because the regularization prevents
    # overfitting on the small data set. We therefore include all vars when
    # regularizing and computing loss during training.
    def loss_filter_fn(name):
        return True

    return resnet.resnet_model_fn(features,
                                  labels,
                                  mode,
                                  Cifar10Model,
                                  resnet_size=params['resnet_size'],
                                  weight_decay=weight_decay,
                                  learning_rate_fn=learning_rate_fn,
                                  momentum=0.9,
                                  data_format=params['data_format'],
                                  loss_filter_fn=loss_filter_fn,
                                  multi_gpu=params['multi_gpu'])
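
All of the examples here lean on learning_rate_with_decay. In the official TensorFlow ResNet code this helper returns a piecewise-constant schedule; the sketch below captures the idea (the 0.1 * batch_size / batch_denom base rate and the internals are assumptions drawn from that reference implementation, not from the snippet above).

import tensorflow as tf

def learning_rate_with_decay_sketch(batch_size, batch_denom, num_images,
                                    boundary_epochs, decay_rates):
  # Base rate scales linearly with the batch size, per the reference code.
  initial_learning_rate = 0.1 * batch_size / batch_denom
  batches_per_epoch = num_images / batch_size

  # Convert epoch boundaries into global-step boundaries, and the decay
  # multipliers into absolute learning rates.
  boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
  vals = [initial_learning_rate * decay for decay in decay_rates]

  def learning_rate_fn(global_step):
    return tf.train.piecewise_constant(global_step, boundaries, vals)

  return learning_rate_fn

With batch_size=128 and batch_denom=128 as in Example #1, the schedule starts at 0.1 and drops to 0.01, 0.001, and 0.0001 at epochs 100, 150, and 200.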
Example #2
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

  learning_rate_fn = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200],
      decay_rates=[1, 0.1, 0.01, 0.001])

  # We use a weight decay of 0.0002, which performs better
  # than the 0.0001 that was originally suggested.
  weight_decay = 2e-4

  # Empirical testing showed that including batch_normalization variables
  # in the calculation of regularized loss helped validation accuracy
  # for the CIFAR-10 dataset, perhaps because the regularization prevents
  # overfitting on the small data set. We therefore include all vars when
  # regularizing and computing loss during training.
  def loss_filter_fn(name):
    return True

  return resnet.resnet_model_fn(features, labels, mode, Cifar10Model,
                                resnet_size=params['resnet_size'],
                                weight_decay=weight_decay,
                                learning_rate_fn=learning_rate_fn,
                                momentum=0.9,
                                data_format=params['data_format'],
                                loss_filter_fn=loss_filter_fn)
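
The loss_filter_fn hook decides which variables enter the L2 penalty inside resnet_model_fn. A minimal sketch of how the reference implementation applies it (the helper shown is illustrative; only the filtering pattern is taken from the official code):

import tensorflow as tf

def regularized_loss(cross_entropy, weight_decay, loss_filter_fn):
  # Returning True for every name, as the CIFAR-10 examples do, pulls
  # batch-norm variables into the penalty as well; passing
  # loss_filter_fn=None falls back to a default that excludes them
  # in the reference code.
  l2 = weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
       if loss_filter_fn(v.name)])
  return cross_entropy + l2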
Example #3
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10."""
    features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])
    schedule_fn = resnet.schedule_with_warm_restarts(
        batch_size=params['batch_size'],
        num_images=_NUM_IMAGES['train'],
        t_0=100,
        m_mul=0.0)  # no restarts

    # weight_decay, lr, and optimizer_base are undefined in the original
    # snippet; the assignments below are plausible assumptions added so the
    # function is self-contained.
    weight_decay = 2e-4
    lr = params['initial_learning_rate']
    optimizer_base = FLAGS.optimizer_base

    # Empirical testing showed that including batch_normalization variables
    # in the calculation of regularized loss helped validation accuracy
    # for the CIFAR-10 dataset, perhaps because the regularization prevents
    # overfitting on the small data set. We therefore include all vars when
    # regularizing and computing loss during training.
    def loss_filter_fn(name):
        return True

    return resnet.resnet_model_fn(
        features,
        labels,
        mode,
        Cifar10Model,
        resnet_size=params['resnet_size'],
        weight_decay=weight_decay,
        initial_learning_rate=lr,
        schedule_fn=schedule_fn,
        momentum=0.9,
        data_format=params['data_format'],
        loss_filter_fn=loss_filter_fn,
        decouple_weight_decay=FLAGS.decouple_weight_decay,
        optimizer_base=optimizer_base)
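
schedule_with_warm_restarts is not a standard TensorFlow helper, but its t_0/m_mul signature mirrors TensorFlow's built-in cosine decay with restarts (SGDR). A plausible sketch under that assumption (initial_learning_rate here is a hypothetical default), showing why m_mul=0.0 amounts to no restarts:

import tensorflow as tf

def schedule_with_warm_restarts_sketch(batch_size, num_images, t_0, m_mul,
                                       initial_learning_rate=0.1):
    # t_0 is the length of the first cosine cycle in epochs; convert to steps.
    batches_per_epoch = num_images / batch_size
    first_decay_steps = int(batches_per_epoch * t_0)

    def schedule_fn(global_step):
        # m_mul scales each restart's amplitude. With m_mul=0.0 the learning
        # rate never comes back up after the first cycle, i.e. one plain
        # cosine decay and effectively no restarts.
        return tf.train.cosine_decay_restarts(
            initial_learning_rate, global_step, first_decay_steps,
            t_mul=2.0, m_mul=m_mul)

    return schedule_fn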
Example #4
def imagenet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator."""

  if params['flag_bias']:
    # Learning rate schedule for the bias-correction training phase.
    learning_rate_fn = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'], batch_denom=256,
        num_images=params['num_train_images'],
        boundary_epochs=[30 * _BIAS_EPOCHS, 60 * _BIAS_EPOCHS,
                         80 * _BIAS_EPOCHS, 90 * _BIAS_EPOCHS],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
  else:
    # Learning rate schedule for normal training.
    learning_rate_fn = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'], batch_denom=256,
        num_images=params['num_train_images'], boundary_epochs=[30, 60, 80, 90],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  # Use a larger weight decay in the early incremental iterations; it shrinks
  # as the iteration index grows.
  weight_decay = 1e-4 * params['nb_groups'] / (params['itera'] + 1)
  return resnet.resnet_model_fn(features, labels, mode, ImagenetModel,
                                resnet_size=params['resnet_size'],
                                weight_decay=weight_decay,
                                learning_rate_fn=learning_rate_fn,
                                momentum=0.9,
                                data_format=params['data_format'],
                                itera=params['itera'],
                                nb_groups=params['nb_groups'],
                                restore_model_dir=params['restore_model_dir'],
                                flag_bias=params['flag_bias'],
                                loss_filter_fn=None)
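
To make the weight decay schedule concrete, here is a hypothetical walk-through (nb_groups = 4 is an assumed value; the snippet does not fix it):

# Hypothetical illustration of the shrinking weight decay across iterations.
nb_groups = 4  # assumed value
for itera in range(4):
  print(itera, 1e-4 * nb_groups / (itera + 1))
# itera 0 -> 4.0e-4, itera 1 -> 2.0e-4, itera 2 -> ~1.33e-4, itera 3 -> 1.0e-4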
Example #5
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

  epochs_1 = params['train_epochs_mentor']
  learning_rate_fn_mentor = resnet.learning_rate_with_decay_2(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[epochs_1 // 2,
                       3 * epochs_1 // 4,
                       7 * epochs_1 // 8],
      initial_learning_rate=params['initial_learning_rate_mentor'],
      decay_rates=[1, 0.1, 0.01, 0.001])

  epochs_2 = params['train_epochs_mentee']
  learning_rate_fn_mentee = resnet.learning_rate_with_decay_2(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[epochs_1 + epochs_2 // 4,
                       epochs_1 + 3 * epochs_2 // 4,
                       epochs_1 + 7 * epochs_2 // 8],
      initial_learning_rate=params['initial_learning_rate_mentee'],
      decay_rates=[1, 0.1, 0.01, 0.001])

  epochs_3 = params['train_epochs_finetune']
  learning_rate_fn_finetune = resnet.learning_rate_with_decay_2(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[epochs_1 + epochs_2 + epochs_3 // 4,
                       epochs_1 + epochs_2 + 3 * epochs_3 // 4,
                       epochs_1 + epochs_2 + 7 * epochs_3 // 8],
      initial_learning_rate=params['initial_learning_rate_finetune'],
      decay_rates=[1, 0.1, 0.01, 0.001])

  # Empirical testing showed that including batch_normalization variables
  # in the calculation of regularized loss helped validation accuracy
  # for the CIFAR-10 dataset, perhaps because the regularization prevents
  # overfitting on the small data set. We therefore include all vars when
  # regularizing and computing loss during training.
  def loss_filter_fn(name):
    return True

  return resnet.resnet_model_fn(features, labels, mode, Cifar10Model,
                                resnet_size=params['resnet_size'],
                                learning_rate_fn_mentor=learning_rate_fn_mentor,
                                learning_rate_fn_mentee=learning_rate_fn_mentee,
                                learning_rate_fn_finetune=learning_rate_fn_finetune,
                                momentum=0.9,
                                temperature=params['temperature'],
                                num_probes=params['num_probes'],
                                distillation_coeff=params['distillation_coeff'],
                                weight_decay_coeff=params['weight_decay_coeff'],
                                probes_coeff=params['probes_coeff'],
                                optimizer=params['optimizer'],
                                trainee=params['trainee'],
                                data_format=params['data_format'],
                                pool_probes=params['pool_probes'],
                                pool_type=params['pool_type'],
                                loss_filter_fn=loss_filter_fn)
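
The three schedules are offset so each phase's decay boundaries fall after the previous phase ends on the shared global step. A hypothetical calculation with all three phase lengths set to 100 epochs (assumed values) shows the resulting boundary epochs:

# Hypothetical boundary epochs for the three chained phases.
epochs_1 = epochs_2 = epochs_3 = 100  # assumed for illustration

mentor = [epochs_1 // 2, 3 * epochs_1 // 4, 7 * epochs_1 // 8]
mentee = [epochs_1 + epochs_2 // 4,
          epochs_1 + 3 * epochs_2 // 4,
          epochs_1 + 7 * epochs_2 // 8]
finetune = [epochs_1 + epochs_2 + epochs_3 // 4,
            epochs_1 + epochs_2 + 3 * epochs_3 // 4,
            epochs_1 + epochs_2 + 7 * epochs_3 // 8]

print(mentor)    # [50, 75, 87]
print(mentee)    # [125, 175, 187]
print(finetune)  # [225, 275, 287]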
Example #6
def imagenet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator."""
  learning_rate_fn = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=256,
      num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  return resnet.resnet_model_fn(features, labels, mode, ImagenetModel,
                                resnet_size=params['resnet_size'],
                                weight_decay=1e-4,
                                learning_rate_fn=learning_rate_fn,
                                momentum=0.9,
                                data_format=params['data_format'],
                                loss_filter_fn=None)
Example #7
def imagenet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator."""
    learning_rate_fn = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=256,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=[30, 60, 80, 90],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

    return resnet.resnet_model_fn(features,
                                  labels,
                                  mode,
                                  ImagenetModel,
                                  resnet_size=params['resnet_size'],
                                  weight_decay=1e-4,
                                  learning_rate_fn=learning_rate_fn,
                                  momentum=0.9,
                                  data_format=params['data_format'],
                                  loss_filter_fn=None)
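
Every model_fn above plugs into the same TF 1.x Estimator harness. A minimal usage sketch; model_dir, the params values, and train_input_fn are placeholders rather than values from the snippets:

import tensorflow as tf

classifier = tf.estimator.Estimator(
    model_fn=imagenet_model_fn,
    model_dir='/tmp/resnet_model',  # placeholder checkpoint directory
    params={
        'resnet_size': 50,          # assumed model size
        'batch_size': 256,
        'data_format': None,        # let the model pick for the device
    })

# train_input_fn stands in for the tf.data input pipeline that feeds the model.
classifier.train(input_fn=train_input_fn)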